diff --git a/comfy/model_management.py b/comfy/model_management.py
index 94b335620..c4f44f8ef 100644
--- a/comfy/model_management.py
+++ b/comfy/model_management.py
@@ -575,7 +575,9 @@ def loaded_models(only_currently_used=False):
 def cleanup_models(keep_clone_weights_loaded=False):
     to_delete = []
     for i in range(len(current_loaded_models)):
-        if sys.getrefcount(current_loaded_models[i].model) <= 2:
+        #TODO: very fragile function needs improvement
+        num_refs = sys.getrefcount(current_loaded_models[i].model) - current_loaded_models[i].model.lowvram_patch_counter()
+        if num_refs <= 2:
             if not keep_clone_weights_loaded:
                 to_delete = [i] + to_delete
             #TODO: find a less fragile way to do this.
@@ -898,7 +900,8 @@ def pytorch_attention_flash_attention():
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
     try:
-        if platform.mac_ver()[0] in ['14.5']: #black image bug on OSX Sonoma 14.5
+        macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
+        if (14, 5) <= macos_version < (14, 7):  # black image bug on recent versions of MacOS
             upcast = True
     except:
         pass
diff --git a/comfy/ops.py b/comfy/ops.py
index 5fef7cee7..d7596634e 100644
--- a/comfy/ops.py
+++ b/comfy/ops.py
@@ -20,31 +20,36 @@ import torch
 import comfy.model_management
 from comfy.cli_args import args

-def cast_to(weight, dtype=None, device=None, non_blocking=False):
-    if (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
+def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=True):
+    if not copy and (dtype is None or weight.dtype == dtype) and (device is None or weight.device == device):
         return weight
     r = torch.empty_like(weight, dtype=dtype, device=device)
     r.copy_(weight, non_blocking=non_blocking)
     return r

-def cast_to_input(weight, input, non_blocking=False):
-    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking)
+def cast_to_input(weight, input, non_blocking=False, copy=True):
+    return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)

-def cast_bias_weight(s, input=None, dtype=None, device=None):
+def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None):
     if input is not None:
         if dtype is None:
             dtype = input.dtype
+        if bias_dtype is None:
+            bias_dtype = dtype
         if device is None:
             device = input.device

     bias = None
     non_blocking = comfy.model_management.device_supports_non_blocking(device)
     if s.bias is not None:
-        bias = cast_to(s.bias, dtype, device, non_blocking=non_blocking)
-        if s.bias_function is not None:
+        has_function = s.bias_function is not None
+        bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function)
+        if has_function:
             bias = s.bias_function(bias)

-    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking)
-    if s.weight_function is not None:
+
+    has_function = s.weight_function is not None
+    weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function)
+    if has_function:
         weight = s.weight_function(weight)
     return weight, bias

@@ -252,7 +257,8 @@ def fp8_linear(self, input):
     if len(input.shape) == 3:
         inn = input.reshape(-1, input.shape[2]).to(dtype)
         non_blocking = comfy.model_management.device_supports_non_blocking(input.device)
-        w = cast_to(self.weight, device=input.device, non_blocking=non_blocking).t()
+        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+        w = w.t()

         scale_weight = self.scale_weight
         scale_input = self.scale_input
@@ -263,8 +269,8 @@ def fp8_linear(self, input):
         if scale_input is None:
             scale_input = torch.ones((1), device=input.device, dtype=torch.float32)

-        if self.bias is not None:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=cast_to_input(self.bias, input, non_blocking=non_blocking), scale_a=scale_input, scale_b=scale_weight)
+        if bias is not None:
+            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
         else:
             o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)

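Notes on the hunks above (illustrative sketches, not part of the patch; all helper names below are hypothetical):

The cleanup_models() heuristic leans on CPython reference counting: sys.getrefcount() reports one extra reference for the temporary argument it receives, so a count of 2 means only the loaded-model entry itself still holds the model. The subtraction compensates for references held by low-VRAM weight patches, which would otherwise keep the model looking "in use" forever. A distilled version of the idea:

    import sys

    def only_we_hold_it(model, extra_internal_refs=0):
        # getrefcount() counts its own temporary argument plus the reference
        # held by the owning entry, hence the baseline of 2. Any internal
        # bookkeeping references (e.g. lowvram patches) are subtracted out.
        return sys.getrefcount(model) - extra_internal_refs <= 2

The macOS check works because platform.mac_ver()[0] is a dotted version string and Python tuples compare lexicographically, so (14, 5) <= version < (14, 7) covers every 14.5.x and 14.6.x patch release without listing each one. On non-Mac platforms mac_ver()[0] is an empty string, int("") raises ValueError, and the surrounding try/except leaves upcast unchanged:

    import platform

    ver = tuple(int(n) for n in platform.mac_ver()[0].split("."))  # e.g. (14, 5, 1)
    affected = (14, 5) <= ver < (14, 7)  # True for 14.5.0 through 14.6.x

The copy flag on cast_to() lets cast_bias_weight() allocate a fresh tensor only when a weight_function/bias_function is about to transform it (copy=has_function); copying first protects the cached parameter from mutation, while the no-function, matching-dtype/device path can hand back the stored tensor as-is.

fp8_linear() now goes through cast_bias_weight() so the weight comes back in the fp8 compute dtype while the bias stays in the activation dtype (bias_dtype=input.dtype), and any weight/bias functions are applied consistently. For reference, a sketch of the torch._scaled_mm call pattern; _scaled_mm is a private PyTorch API that needs fp8-capable hardware (e.g. Ada/Hopper) and whose return convention has changed across PyTorch versions:

    import torch

    x = torch.randn(16, 64, device="cuda").to(torch.float8_e4m3fn)
    w = torch.randn(32, 64, device="cuda").to(torch.float8_e4m3fn)
    one = torch.ones((1), device="cuda", dtype=torch.float32)
    # .t() hands _scaled_mm the column-major layout it expects for operand b.
    out = torch._scaled_mm(x, w.t(), scale_a=one, scale_b=one,
                           out_dtype=torch.bfloat16)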
diff --git a/comfy/supported_models.py b/comfy/supported_models.py
index 98aa8056b..3603313fa 100644
--- a/comfy/supported_models.py
+++ b/comfy/supported_models.py
@@ -654,6 +654,7 @@ class Flux(supported_models_base.BASE):
     def clip_target(self, state_dict={}):
         pref = self.text_encoder_key_prefix[0]
         t5_key = "{}t5xxl.transformer.encoder.final_layer_norm.weight".format(pref)
+        dtype_t5 = None
         if t5_key in state_dict:
             dtype_t5 = state_dict[t5_key].dtype
         return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(dtype_t5=dtype_t5))
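Note on the hunk above: the added dtype_t5 = None default matters because a local variable only exists once assigned; without it, a checkpoint that lacks the t5xxl key would make the return line raise UnboundLocalError instead of passing dtype_t5=None through to flux_clip(). A minimal reproduction of the failure mode (hypothetical function, illustration only):

    def broken(state_dict, key):
        if key in state_dict:
            dtype = state_dict[key].dtype
        return dtype  # UnboundLocalError when key is absent

    try:
        broken({}, "missing")
    except UnboundLocalError as e:
        print(e)  # local variable 'dtype' referenced before assignment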