Merge branch 'comfyanonymous:master' into master

This commit is contained in:
patientx 2024-12-25 14:05:18 +03:00 committed by GitHub
commit 49fa16cc7a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 39 additions and 41 deletions

View File

@@ -188,6 +188,12 @@ def is_nvidia():
             return True
     return False

+def is_amd():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.hip:
+            return True
+    return False
+
 MIN_WEIGHT_MEMORY_RATIO = 0.4

 if is_nvidia():
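
For context on what the new is_amd() helper keys off: on a ROCm build of PyTorch, torch.version.hip is set to the HIP version string, while on CUDA builds it is None (and torch.version.cuda is set instead). A minimal check, not part of the diff, that works on any PyTorch install:

import torch

# On a ROCm (AMD) build: torch.version.hip is e.g. "6.1.40091" and torch.version.cuda is None.
# On a CUDA (NVIDIA) build: torch.version.cuda is e.g. "12.4" and torch.version.hip is None.
print("cuda:", torch.version.cuda, "hip:", torch.version.hip)
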
@@ -198,27 +204,17 @@ if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
     XFORMERS_IS_AVAILABLE = False

-VAE_DTYPES = [torch.float32]
-
 try:
     if is_nvidia():
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-            if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
-                VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
     if is_intel_xpu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
     pass

-if is_intel_xpu():
-    VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
-
-if args.cpu_vae:
-    VAE_DTYPES = [torch.float32]
-
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
@@ -767,7 +763,6 @@ def vae_offload_device():
     return torch.device("cpu")

 def vae_dtype(device=None, allowed_dtypes=[]):
-    global VAE_DTYPES
     if args.fp16_vae:
         return torch.float16
     elif args.bf16_vae:
@ -776,12 +771,14 @@ def vae_dtype(device=None, allowed_dtypes=[]):
return torch.float32 return torch.float32
for d in allowed_dtypes: for d in allowed_dtypes:
if d == torch.float16 and should_use_fp16(device, prioritize_performance=False): if d == torch.float16 and should_use_fp16(device):
return d
if d in VAE_DTYPES:
return d return d
return VAE_DTYPES[0] # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
return d
return torch.float32
def get_autocast_device(dev): def get_autocast_device(dev):
if hasattr(dev, 'type'): if hasattr(dev, 'type'):
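
For reference, the reworked vae_dtype() selection above boils down to: honor the explicit --fp16-vae / --bf16-vae / --fp32-vae flags first, then walk the caller-supplied allowed_dtypes and take fp16 or (on non-AMD hardware) bf16 when the device supports it, otherwise fall back to fp32. A standalone sketch of that ordering, not part of the diff, with made-up boolean parameters standing in for ComfyUI's should_use_fp16 / should_use_bf16 / is_amd:

import torch

def pick_vae_dtype(allowed_dtypes, supports_fp16, supports_bf16, on_amd):
    # Stand-in for the selection order in vae_dtype(): the first usable entry wins.
    for d in allowed_dtypes:
        if d == torch.float16 and supports_fp16:
            return d
        # bf16 is skipped on AMD because it can be far slower than fp32 for the VAE.
        if d == torch.bfloat16 and not on_amd and supports_bf16:
            return d
    return torch.float32

# Example: a model that prefers bf16, running on an NVIDIA card with bf16 support.
print(pick_vae_dtype([torch.bfloat16, torch.float16],
                     supports_fp16=True, supports_bf16=True, on_amd=False))
# -> torch.bfloat16
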
@ -902,14 +899,19 @@ def pytorch_attention_flash_attention():
return True return True
return False return False
def mac_version():
try:
return tuple(int(n) for n in platform.mac_ver()[0].split("."))
except:
return None
def force_upcast_attention_dtype(): def force_upcast_attention_dtype():
upcast = args.force_upcast_attention upcast = args.force_upcast_attention
try:
macos_version = tuple(int(n) for n in platform.mac_ver()[0].split(".")) macos_version = mac_version()
if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS
upcast = True upcast = True
except:
pass
if upcast: if upcast:
return torch.float32 return torch.float32
else: else:
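
The mac_version() helper added above returns the macOS version as a tuple of ints (or None off-macOS or on a parse failure), which turns the range check in force_upcast_attention_dtype() into a plain tuple comparison. A small illustration of how that comparison behaves; the version values here are just examples:

# Python compares tuples element by element, so a three-part version like
# (15, 0, 1) still sorts correctly between the two-part bounds.
assert (14, 5) <= (15, 0, 1) <= (15, 2)      # affected macOS build -> upcast
assert not ((14, 5) <= (14, 4) <= (15, 2))   # older build -> leave attention dtype alone
# Off macOS, platform.mac_ver()[0] is "", so int("") raises and mac_version()
# returns None; the `is not None` guard then skips the version check entirely.
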
@ -980,17 +982,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
if FORCE_FP16: if FORCE_FP16:
return True return True
if device is not None:
if is_device_mps(device):
return True
if FORCE_FP32: if FORCE_FP32:
return False return False
if directml_enabled: if directml_enabled:
return False return False
if mps_mode(): if (device is not None and is_device_mps(device)) or mps_mode():
return True return True
if cpu_mode(): if cpu_mode():
@ -1039,17 +1037,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow if is_device_cpu(device): #TODO ? bf16 works on CPU but is extremely slow
return False return False
if device is not None:
if is_device_mps(device):
return True
if FORCE_FP32: if FORCE_FP32:
return False return False
if directml_enabled: if directml_enabled:
return False return False
if mps_mode(): if (device is not None and is_device_mps(device)) or mps_mode():
if mac_version() < (14,):
return False
return True return True
if cpu_mode(): if cpu_mode():

View File

@ -255,9 +255,10 @@ def fp8_linear(self, input):
tensor_2d = True tensor_2d = True
input = input.unsqueeze(1) input = input.unsqueeze(1)
input_shape = input.shape
input_dtype = input.dtype
if len(input.shape) == 3: if len(input.shape) == 3:
w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype) w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
w = w.t() w = w.t()
scale_weight = self.scale_weight scale_weight = self.scale_weight
@ -269,23 +270,24 @@ def fp8_linear(self, input):
if scale_input is None: if scale_input is None:
scale_input = torch.ones((), device=input.device, dtype=torch.float32) scale_input = torch.ones((), device=input.device, dtype=torch.float32)
inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype) input = torch.clamp(input, min=-448, max=448, out=input)
input = input.reshape(-1, input_shape[2]).to(dtype)
else: else:
scale_input = scale_input.to(input.device) scale_input = scale_input.to(input.device)
inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype) input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
if bias is not None: if bias is not None:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
else: else:
o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight) o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
if isinstance(o, tuple): if isinstance(o, tuple):
o = o[0] o = o[0]
if tensor_2d: if tensor_2d:
return o.reshape(input.shape[0], -1) return o.reshape(input_shape[0], -1)
return o.reshape((-1, input.shape[1], self.weight.shape[0])) return o.reshape((-1, input_shape[1], self.weight.shape[0]))
return None return None
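
A note that may help when reading the fp8_linear() changes above: the refactor reuses the input tensor in place (together with the saved input_shape and input_dtype) instead of allocating the temporary inn, but the operation it feeds is the same torch._scaled_mm fp8 matmul. Below is a minimal sketch of that call outside ComfyUI, assuming a CUDA build of PyTorch with float8 support and a GPU that has fp8 kernels (e.g. Ada/Hopper); the tensor names and shapes are made up for the example:

import torch

x = torch.randn(4, 16, device="cuda")    # activations: (tokens, in_features)
w = torch.randn(32, 16, device="cuda")   # weight: (out_features, in_features)

# Quantize to float8_e4m3fn; +/-448 is that format's finite range, hence the clamp.
x_fp8 = torch.clamp(x, min=-448, max=448).to(torch.float8_e4m3fn)
w_fp8 = w.to(torch.float8_e4m3fn).t()    # _scaled_mm expects the second operand transposed

scale = torch.ones((), device="cuda", dtype=torch.float32)  # per-tensor scales
o = torch._scaled_mm(x_fp8, w_fp8, out_dtype=torch.float16, scale_a=scale, scale_b=scale)
if isinstance(o, tuple):                 # some older PyTorch versions also return an amax tensor
    o = o[0]
print(o.shape)                           # -> torch.Size([4, 32])
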

View File

@@ -111,7 +111,7 @@ class CLIP:
             model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
         self.use_clip_schedule = False
-        logging.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
+        logging.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))

     def clone(self):
         n = CLIP(no_init=True)
@ -402,7 +402,7 @@ class VAE:
self.output_device = model_management.intermediate_device() self.output_device = model_management.intermediate_device()
self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
logging.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
def vae_encode_crop_pixels(self, pixels): def vae_encode_crop_pixels(self, pixels):
downscale_ratio = self.spacial_compression_encode() downscale_ratio = self.spacial_compression_encode()