mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 23:00:20 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
49fa16cc7a
@ -188,6 +188,12 @@ def is_nvidia():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def is_amd():
|
||||||
|
global cpu_state
|
||||||
|
if cpu_state == CPUState.GPU:
|
||||||
|
if torch.version.hip:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
MIN_WEIGHT_MEMORY_RATIO = 0.4
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
@ -198,27 +204,17 @@ if args.use_pytorch_cross_attention:
|
|||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
XFORMERS_IS_AVAILABLE = False
|
XFORMERS_IS_AVAILABLE = False
|
||||||
|
|
||||||
VAE_DTYPES = [torch.float32]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if is_nvidia():
|
if is_nvidia():
|
||||||
if int(torch_version[0]) >= 2:
|
if int(torch_version[0]) >= 2:
|
||||||
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
|
|
||||||
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
|
||||||
if is_intel_xpu():
|
if is_intel_xpu():
|
||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if is_intel_xpu():
|
|
||||||
VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
|
|
||||||
|
|
||||||
if args.cpu_vae:
|
|
||||||
VAE_DTYPES = [torch.float32]
|
|
||||||
|
|
||||||
if ENABLE_PYTORCH_ATTENTION:
|
if ENABLE_PYTORCH_ATTENTION:
|
||||||
torch.backends.cuda.enable_math_sdp(True)
|
torch.backends.cuda.enable_math_sdp(True)
|
||||||
torch.backends.cuda.enable_flash_sdp(True)
|
torch.backends.cuda.enable_flash_sdp(True)
|
||||||
@ -767,7 +763,6 @@ def vae_offload_device():
|
|||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
def vae_dtype(device=None, allowed_dtypes=[]):
|
def vae_dtype(device=None, allowed_dtypes=[]):
|
||||||
global VAE_DTYPES
|
|
||||||
if args.fp16_vae:
|
if args.fp16_vae:
|
||||||
return torch.float16
|
return torch.float16
|
||||||
elif args.bf16_vae:
|
elif args.bf16_vae:
|
||||||
@ -776,12 +771,14 @@ def vae_dtype(device=None, allowed_dtypes=[]):
|
|||||||
return torch.float32
|
return torch.float32
|
||||||
|
|
||||||
for d in allowed_dtypes:
|
for d in allowed_dtypes:
|
||||||
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
|
if d == torch.float16 and should_use_fp16(device):
|
||||||
return d
|
|
||||||
if d in VAE_DTYPES:
|
|
||||||
return d
|
return d
|
||||||
|
|
||||||
return VAE_DTYPES[0]
|
# NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
|
||||||
|
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
|
||||||
|
return d
|
||||||
|
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
def get_autocast_device(dev):
|
def get_autocast_device(dev):
|
||||||
if hasattr(dev, 'type'):
|
if hasattr(dev, 'type'):
|
||||||
@ -902,14 +899,19 @@ def pytorch_attention_flash_attention():
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def mac_version():
|
||||||
|
try:
|
||||||
|
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
||||||
|
except:
|
||||||
|
return None
|
||||||
|
|
||||||
def force_upcast_attention_dtype():
|
def force_upcast_attention_dtype():
|
||||||
upcast = args.force_upcast_attention
|
upcast = args.force_upcast_attention
|
||||||
try:
|
|
||||||
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
|
macos_version = mac_version()
|
||||||
if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS
|
if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
|
||||||
upcast = True
|
upcast = True
|
||||||
except:
|
|
||||||
pass
|
|
||||||
if upcast:
|
if upcast:
|
||||||
return torch.float32
|
return torch.float32
|
||||||
else:
|
else:
|
||||||
@ -980,17 +982,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if FORCE_FP16:
|
if FORCE_FP16:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if device is not None:
|
|
||||||
if is_device_mps(device):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if FORCE_FP32:
|
if FORCE_FP32:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if mps_mode():
|
if (device is not None and is_device_mps(device)) or mps_mode():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if cpu_mode():
|
if cpu_mode():
|
||||||
@ -1039,17 +1037,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
|
|||||||
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
|
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if device is not None:
|
|
||||||
if is_device_mps(device):
|
|
||||||
return True
|
|
||||||
|
|
||||||
if FORCE_FP32:
|
if FORCE_FP32:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if mps_mode():
|
if (device is not None and is_device_mps(device)) or mps_mode():
|
||||||
|
if mac_version() < (14,):
|
||||||
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if cpu_mode():
|
if cpu_mode():
|
||||||
|
|||||||
18
comfy/ops.py
18
comfy/ops.py
@ -255,9 +255,10 @@ def fp8_linear(self, input):
|
|||||||
tensor_2d = True
|
tensor_2d = True
|
||||||
input = input.unsqueeze(1)
|
input = input.unsqueeze(1)
|
||||||
|
|
||||||
|
input_shape = input.shape
|
||||||
|
input_dtype = input.dtype
|
||||||
if len(input.shape) == 3:
|
if len(input.shape) == 3:
|
||||||
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
|
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
|
||||||
w = w.t()
|
w = w.t()
|
||||||
|
|
||||||
scale_weight = self.scale_weight
|
scale_weight = self.scale_weight
|
||||||
@ -269,23 +270,24 @@ def fp8_linear(self, input):
|
|||||||
|
|
||||||
if scale_input is None:
|
if scale_input is None:
|
||||||
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
scale_input = torch.ones((), device=input.device, dtype=torch.float32)
|
||||||
inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
|
input = torch.clamp(input, min=-448, max=448, out=input)
|
||||||
|
input = input.reshape(-1, input_shape[2]).to(dtype)
|
||||||
else:
|
else:
|
||||||
scale_input = scale_input.to(input.device)
|
scale_input = scale_input.to(input.device)
|
||||||
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
|
input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
|
||||||
|
|
||||||
if bias is not None:
|
if bias is not None:
|
||||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
|
||||||
else:
|
else:
|
||||||
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
|
o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
|
||||||
|
|
||||||
if isinstance(o, tuple):
|
if isinstance(o, tuple):
|
||||||
o = o[0]
|
o = o[0]
|
||||||
|
|
||||||
if tensor_2d:
|
if tensor_2d:
|
||||||
return o.reshape(input.shape[0], -1)
|
return o.reshape(input_shape[0], -1)
|
||||||
|
|
||||||
return o.reshape((-1, input.shape[1], self.weight.shape[0]))
|
return o.reshape((-1, input_shape[1], self.weight.shape[0]))
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|||||||
@ -111,7 +111,7 @@ class CLIP:
|
|||||||
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
model_management.load_models_gpu([self.patcher], force_full_load=True)
|
||||||
self.layer_idx = None
|
self.layer_idx = None
|
||||||
self.use_clip_schedule = False
|
self.use_clip_schedule = False
|
||||||
logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
|
logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
|
||||||
|
|
||||||
def clone(self):
|
def clone(self):
|
||||||
n = CLIP(no_init=True)
|
n = CLIP(no_init=True)
|
||||||
@ -402,7 +402,7 @@ class VAE:
|
|||||||
self.output_device = model_management.intermediate_device()
|
self.output_device = model_management.intermediate_device()
|
||||||
|
|
||||||
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
|
||||||
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
|
||||||
|
|
||||||
def vae_encode_crop_pixels(self, pixels):
|
def vae_encode_crop_pixels(self, pixels):
|
||||||
downscale_ratio = self.spacial_compression_encode()
|
downscale_ratio = self.spacial_compression_encode()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user