From b1242568178c922c3417938be5a3259d1c395da1 Mon Sep 17 00:00:00 2001
From: HishamC <140008308+hisham-hchowdhu@users.noreply.github.com>
Date: Tue, 11 Feb 2025 14:11:32 -0800
Subject: [PATCH 1/5] Fix for running via DirectML (#6542)

* Fix for running via DirectML

Fix the DirectML empty-image generation issue with Flux1 and add a CPU
fallback for the unsupported path. Verified that the model works on AMD GPUs.

* fix formatting

* update causal mask calculation

---
 comfy/clip_model.py       | 6 +++++-
 comfy/ldm/flux/math.py    | 2 +-
 comfy/model_management.py | 7 +++++++
 3 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index c48576028..0163c6fe7 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -104,7 +104,11 @@ class CLIPTextModel_(torch.nn.Module):
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
             mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

-        causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(-torch.finfo(x.dtype).max).triu_(1)
+        if comfy.model_management.is_directml_enabled():
+            causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
+        else:
+            causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+
         if mask is not None:
             mask += causal_mask
         else:
diff --git a/comfy/ldm/flux/math.py b/comfy/ldm/flux/math.py
index b5960ffd3..36b67931c 100644
--- a/comfy/ldm/flux/math.py
+++ b/comfy/ldm/flux/math.py
@@ -22,7 +22,7 @@ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:

 def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
     assert dim % 2 == 0
-    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu():
+    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
         device = torch.device("cpu")
     else:
         device = pos.device
diff --git a/comfy/model_management.py b/comfy/model_management.py
index 28083fbf9..29cd43b51 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -991,6 +991,13 @@ def is_device_mps(device):
 def is_device_cuda(device):
     return is_device_type(device, 'cuda')

+def is_directml_enabled():
+    global directml_enabled
+    if directml_enabled:
+        return True
+
+    return False
+
 def should_use_fp16(device=None, model_params=0, prioritize_performance=True, manual_cast=False):
     global directml_enabled

From d9f0fcdb0cdfd8f6fd0ec2ee14ea332bb87fd504 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Tue, 11 Feb 2025 17:17:03 -0500
Subject: [PATCH 2/5] Cleanup.

---
 comfy/clip_model.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/comfy/clip_model.py b/comfy/clip_model.py
index 0163c6fe7..cf5b58b62 100644
--- a/comfy/clip_model.py
+++ b/comfy/clip_model.py
@@ -104,10 +104,7 @@ class CLIPTextModel_(torch.nn.Module):
             mask = 1.0 - attention_mask.to(x.dtype).reshape((attention_mask.shape[0], 1, -1, attention_mask.shape[-1])).expand(attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1])
             mask = mask.masked_fill(mask.to(torch.bool), -torch.finfo(x.dtype).max)

-        if comfy.model_management.is_directml_enabled():
-            causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)
-        else:
-            causal_mask = torch.empty(x.shape[1], x.shape[1], dtype=x.dtype, device=x.device).fill_(float("-inf")).triu_(1)
+        causal_mask = torch.full((x.shape[1], x.shape[1]), -torch.finfo(x.dtype).max, dtype=x.dtype, device=x.device).triu_(1)

         if mask is not None:
             mask += causal_mask

From ab888e1e0b8a7558081713241172d0a38f837e16 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 12 Feb 2025 05:49:00 -0500
Subject: [PATCH 3/5] Add add_weight_wrapper function to model patcher.

Functions can now easily be added to wrap/modify model weights.
---
 comfy/model_patcher.py | 61 ++++++++++++++++++++++++++++++++++--------
 comfy/ops.py           | 33 ++++++++++++-----------
 2 files changed, 67 insertions(+), 27 deletions(-)

diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py
index 0501f7b38..aee0164c5 100644
--- a/comfy/model_patcher.py
+++ b/comfy/model_patcher.py
@@ -96,8 +96,28 @@ def wipe_lowvram_weight(m):
     if hasattr(m, "prev_comfy_cast_weights"):
         m.comfy_cast_weights = m.prev_comfy_cast_weights
         del m.prev_comfy_cast_weights
-    m.weight_function = None
-    m.bias_function = None
+
+    if hasattr(m, "weight_function"):
+        m.weight_function = []
+
+    if hasattr(m, "bias_function"):
+        m.bias_function = []
+
+def move_weight_functions(m, device):
+    if device is None:
+        return 0
+
+    memory = 0
+    if hasattr(m, "weight_function"):
+        for f in m.weight_function:
+            if hasattr(f, "move_to"):
+                memory += f.move_to(device=device)
+
+    if hasattr(m, "bias_function"):
+        for f in m.bias_function:
+            if hasattr(f, "move_to"):
+                memory += f.move_to(device=device)
+    return memory

 class LowVramPatch:
     def __init__(self, key, patches):
@@ -192,6 +212,7 @@ class ModelPatcher:
         self.backup = {}
         self.object_patches = {}
         self.object_patches_backup = {}
+        self.weight_wrapper_patches = {}
         self.model_options = {"transformer_options":{}}
         self.model_size()
         self.load_device = load_device
@@ -250,6 +271,7 @@ class ModelPatcher:
         n.patches_uuid = self.patches_uuid

         n.object_patches = self.object_patches.copy()
+        n.weight_wrapper_patches = self.weight_wrapper_patches.copy()
         n.model_options = copy.deepcopy(self.model_options)
         n.backup = self.backup
         n.object_patches_backup = self.object_patches_backup
@@ -402,6 +424,10 @@ class ModelPatcher:
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj

+    def add_weight_wrapper(self, name, function):
+        self.weight_wrapper_patches[name] = self.weight_wrapper_patches.get(name, []) + [function]
+        self.patches_uuid = uuid.uuid4()
+
     def get_model_object(self, name: str) -> torch.nn.Module:
         """Retrieves a nested attribute from an object using dot notation considering object patches.
@@ -566,6 +592,9 @@ class ModelPatcher:

             lowvram_weight = False

+            weight_key = "{}.weight".format(n)
+            bias_key = "{}.bias".format(n)
+
             if not full_load and hasattr(m, "comfy_cast_weights"):
                 if mem_counter + module_mem >= lowvram_model_memory:
                     lowvram_weight = True
@@ -573,34 +602,42 @@ class ModelPatcher:
                     if hasattr(m, "prev_comfy_cast_weights"): #Already lowvramed
                         continue

-            weight_key = "{}.weight".format(n)
-            bias_key = "{}.bias".format(n)
-
             if lowvram_weight:
+                if hasattr(m, "comfy_cast_weights"):
+                    m.weight_function = []
+                    m.bias_function = []
+
                 if weight_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(weight_key)
                     else:
-                        m.weight_function = LowVramPatch(weight_key, self.patches)
+                        m.weight_function = [LowVramPatch(weight_key, self.patches)]
                         patch_counter += 1
                 if bias_key in self.patches:
                     if force_patch_weights:
                         self.patch_weight_to_device(bias_key)
                     else:
-                        m.bias_function = LowVramPatch(bias_key, self.patches)
+                        m.bias_function = [LowVramPatch(bias_key, self.patches)]
                         patch_counter += 1

                 m.prev_comfy_cast_weights = m.comfy_cast_weights
                 m.comfy_cast_weights = True
             else:
                 if hasattr(m, "comfy_cast_weights"):
-                    if m.comfy_cast_weights:
-                        wipe_lowvram_weight(m)
+                    wipe_lowvram_weight(m)

                 if full_load or mem_counter + module_mem < lowvram_model_memory:
                     mem_counter += module_mem
                     load_completely.append((module_mem, n, m, params))

+            if weight_key in self.weight_wrapper_patches:
+                m.weight_function.extend(self.weight_wrapper_patches[weight_key])
+
+            if bias_key in self.weight_wrapper_patches:
+                m.bias_function.extend(self.weight_wrapper_patches[bias_key])
+
+            mem_counter += move_weight_functions(m, device_to)
+
         load_completely.sort(reverse=True)
         for x in load_completely:
             n = x[1]
@@ -662,6 +699,7 @@ class ModelPatcher:
             self.unpatch_hooks()
             if self.model.model_lowvram:
                 for m in self.model.modules():
+                    move_weight_functions(m, device_to)
                     wipe_lowvram_weight(m)

                 self.model.model_lowvram = False
@@ -729,12 +767,13 @@ class ModelPatcher:
                     bias_key = "{}.bias".format(n)
                     if move_weight:
                         m.to(device_to)
+                        module_mem += move_weight_functions(m, device_to)
                         if lowvram_possible:
                             if weight_key in self.patches:
-                                m.weight_function = LowVramPatch(weight_key, self.patches)
+                                m.weight_function.append(LowVramPatch(weight_key, self.patches))
                                 patch_counter += 1
                             if bias_key in self.patches:
-                                m.bias_function = LowVramPatch(bias_key, self.patches)
+                                m.bias_function.append(LowVramPatch(bias_key, self.patches))
                                 patch_counter += 1

                             m.prev_comfy_cast_weights = m.comfy_cast_weights
diff --git a/comfy/ops.py b/comfy/ops.py
index 06be6b48b..30014477e 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -38,21 +38,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
-        has_function = s.bias_function is not None
+        has_function = len(s.bias_function) > 0
         bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
         if has_function:
-            bias = s.bias_function(bias)
+            for f in s.bias_function:
+                bias = f(bias)

-    has_function = s.weight_function is not None
+    has_function = len(s.weight_function) > 0
     weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
     if has_function:
-        weight = s.weight_function(weight)
+        for f in s.weight_function:
+            weight = f(weight)
     return weight, bias

 class CastWeightBiasOp:
     comfy_cast_weights = False
-    weight_function = None
-    bias_function = None
+    weight_function = []
+    bias_function = []

 class disable_weight_init:
     class Linear(torch.nn.Linear, CastWeightBiasOp):
@@ -64,7 +66,7 @@ class disable_weight_init:
             return torch.nn.functional.linear(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -78,7 +80,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -92,7 +94,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -106,7 +108,7 @@ class disable_weight_init:
             return self._conv_forward(input, weight, bias)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -120,12 +122,11 @@ class disable_weight_init:
             return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)

-
     class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
         def reset_parameters(self):
             return None
@@ -139,7 +140,7 @@ class disable_weight_init:
             return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -160,7 +161,7 @@ class disable_weight_init:
                 output_padding, self.groups, self.dilation)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -181,7 +182,7 @@ class disable_weight_init:
                 output_padding, self.groups, self.dilation)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 return super().forward(*args, **kwargs)
@@ -199,7 +200,7 @@ class disable_weight_init:
             return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

         def forward(self, *args, **kwargs):
-            if self.comfy_cast_weights:
+            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                 return self.forward_comfy_cast_weights(*args, **kwargs)
             else:
                 if "out_dtype" in kwargs:

From 35740259de2798cf55098c231b7dab19f15e14da Mon Sep 17 00:00:00 2001
From: zhoufan2956 <78578838+zhoufan2956@users.noreply.github.com>
Date: Wed, 12 Feb 2025 19:48:11 +0800
Subject: [PATCH 4/5] mix_ascend_bf16_infer_err (#6794)

---
 comfy/model_management.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index 29cd43b51..cbf4c4ea6 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1082,6 +1082,9 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma

     if is_intel_xpu():
         return True
+
+    if is_ascend_npu():
+        return True

     props = torch.cuda.get_device_properties(device)
     if props.major >= 8:

From 1d5d6586f300fa54802682802020aafb61bd31a3 Mon Sep 17 00:00:00 2001
From: comfyanonymous
Date: Wed, 12 Feb 2025 06:49:16 -0500
Subject: [PATCH 5/5] Fix ruff.

---
 comfy/model_management.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/comfy/model_management.py b/comfy/model_management.py
index cbf4c4ea6..f3d90c668 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -1082,7 +1082,7 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma

     if is_intel_xpu():
         return True
- 
+
     if is_ascend_npu():
         return True
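
Note on the causal-mask change in PATCH 1/5 and 2/5: PATCH 1/5 routes DirectML through torch.full() filled with the dtype's most negative finite value instead of torch.empty().fill_(), and PATCH 2/5 makes that the only code path. The snippet below is a minimal illustrative sketch of the resulting mask construction, not part of the patches; the tensor x and its shape are stand-ins.

    import torch

    # Stand-in for the hidden states x: (batch, tokens, dim).
    x = torch.zeros(1, 4, 8)

    # Upper-triangular causal mask filled with the largest finite negative
    # value for the dtype, matching the torch.full(...).triu_(1) form the
    # series converges on.
    causal_mask = torch.full(
        (x.shape[1], x.shape[1]),
        -torch.finfo(x.dtype).max,
        dtype=x.dtype,
        device=x.device,
    ).triu_(1)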
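Note on PATCH 3/5: ModelPatcher.add_weight_wrapper(name, function) stores wrapper functions per weight key, carries them over when the patcher is copied, and comfy.ops.cast_bias_weight() applies each one to the weight tensor after it is cast. Below is a minimal usage sketch under stated assumptions: the model variable, the key name, and the scaling function are hypothetical and only illustrate the call pattern shown in the diff.

    def scale_weight(weight):
        # cast_bias_weight() calls each registered wrapper with the cast
        # weight tensor and keeps whatever tensor it returns.
        return weight * 1.05

    patcher = model.clone()  # assumes an existing ModelPatcher instance
    # Keys use the "{module name}.weight" form seen in the diff; this exact
    # key is made up for illustration.
    patcher.add_weight_wrapper("diffusion_model.input_blocks.0.0.weight", scale_weight)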