Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-19 19:00:21 +08:00)

Merge branch 'master' into yousef-higgsv2
Commit 5191fb27e9
@@ -145,7 +145,7 @@ class PerformanceFeature(enum.Enum):
     CublasOps = "cublas_ops"
     AutoTune = "autotune"
 
-parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: fp16_accumulation fp8_matrix_mult cublas_ops")
+parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
 
 parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
 parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
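This hunk builds the list of valid --fast optimizations from the PerformanceFeature enum instead of a hard-coded string, so new members show up in --help automatically. A minimal standalone sketch of the same pattern; the first two member names below are illustrative stand-ins, only CublasOps and AutoTune are visible in the hunk:

import argparse
import enum

class PerformanceFeature(enum.Enum):
    Fp16Accumulation = "fp16_accumulation"   # illustrative member
    Fp8MatrixMult = "fp8_matrix_mult"        # illustrative member
    CublasOps = "cublas_ops"
    AutoTune = "autotune"

parser = argparse.ArgumentParser()
# nargs="*" lets "--fast" alone mean "enable everything" (an empty list),
# while "--fast cublas_ops autotune" parses each token into an enum member by value.
parser.add_argument("--fast", nargs="*", type=PerformanceFeature,
                    help="Current valid optimizations: {}".format(" ".join(c.value for c in PerformanceFeature)))

print(parser.parse_args(["--fast"]).fast)                # []
print(parser.parse_args(["--fast", "cublas_ops"]).fast)  # [<PerformanceFeature.CublasOps: 'cublas_ops'>]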
@@ -253,7 +253,10 @@ class ControlNet(ControlBase):
        to_concat = []
        for c in self.extra_concat_orig:
            c = c.to(self.cond_hint.device)
-            c = comfy.utils.common_upscale(c, self.cond_hint.shape[3], self.cond_hint.shape[2], self.upscale_algorithm, "center")
+            c = comfy.utils.common_upscale(c, self.cond_hint.shape[-1], self.cond_hint.shape[-2], self.upscale_algorithm, "center")
+            if c.ndim < self.cond_hint.ndim:
+                c = c.unsqueeze(2)
+                c = comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[2], dim=2)
            to_concat.append(comfy.utils.repeat_to_batch_size(c, self.cond_hint.shape[0]))
        self.cond_hint = torch.cat([self.cond_hint] + to_concat, dim=1)
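Switching to negative indices (shape[-1]/shape[-2]) plus the unsqueeze/repeat branch lets the extra concat channels follow a 5-D (video-style) cond_hint as well as the original 4-D case. A rough standalone illustration of the shape handling, with a simplified stand-in for comfy.utils.repeat_to_batch_size and the upscale step omitted:

import torch

def repeat_to_batch_size(t, size, dim=0):
    # simplified stand-in: tile dim `dim` up to `size`
    # (the real comfy.utils helper also handles trimming and uneven repeats)
    if t.shape[dim] < size:
        reps = [1] * t.ndim
        reps[dim] = size // t.shape[dim]
        t = t.repeat(*reps)
    return t

cond_hint = torch.zeros(2, 4, 16, 64, 64)  # (batch, channels, frames, H, W): a 5-D hint
extra = torch.zeros(1, 1, 64, 64)          # a single-frame 4-D extra condition

# width/height are always the last two dims, whether the hint is 4-D or 5-D
# (the old shape[3]/shape[2] indexing only worked for 4-D hints)
if extra.ndim < cond_hint.ndim:
    extra = extra.unsqueeze(2)                                       # -> (1, 1, 1, 64, 64)
    extra = repeat_to_batch_size(extra, cond_hint.shape[2], dim=2)   # match frame count
extra = repeat_to_batch_size(extra, cond_hint.shape[0])              # match batch size
print(torch.cat([cond_hint, extra], dim=1).shape)  # torch.Size([2, 5, 16, 64, 64])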
@@ -585,11 +588,18 @@ def load_controlnet_flux_instantx(sd, model_options={}):
 
 def load_controlnet_qwen_instantx(sd, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(sd, model_options=model_options)
-    control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
+    control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]
+
+    extra_condition_channels = 0
+    concat_mask = False
+    if control_latent_channels == 68: #inpaint controlnet
+        extra_condition_channels = control_latent_channels - 64
+        concat_mask = True
+    control_model = comfy.ldm.qwen_image.controlnet.QwenImageControlNetModel(extra_condition_channels=extra_condition_channels, operations=operations, device=offload_device, dtype=unet_dtype, **model_config.unet_config)
     control_model = controlnet_load_state_dict(control_model, sd)
     latent_format = comfy.latent_formats.Wan21()
     extra_conds = []
-    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
+    control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
     return control
 
 def convert_mistoline(sd):
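The loader now infers the Qwen InstantX controlnet variant from the input width of controlnet_x_embedder.weight: 64 latent channels for the regular model, 68 for the inpaint variant (64 plus 4 extra condition channels, with the mask concatenated via concat_mask). A hedged sketch of the same detection against a dummy state dict; the 3072 output width below is illustrative, only the 68-channel input width matters here:

import torch

# dummy state dict standing in for the real checkpoint; only the shape of the
# x_embedder weight matters for variant detection: (out_features, in_features)
sd = {"controlnet_x_embedder.weight": torch.zeros(3072, 68)}

control_latent_channels = sd.get("controlnet_x_embedder.weight").shape[1]

extra_condition_channels = 0
concat_mask = False
if control_latent_channels == 68:  # inpaint controlnet
    extra_condition_channels = control_latent_channels - 64  # 4 extra channels
    concat_mask = True

print(extra_condition_channels, concat_mask)  # 4 True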
@@ -22,6 +22,7 @@ from enum import Enum
 from comfy.cli_args import args, PerformanceFeature
 import torch
 import sys
+import importlib
 import platform
 import weakref
 import gc
@@ -289,6 +290,24 @@ def is_amd():
         return True
     return False
 
+def amd_min_version(device=None, min_rdna_version=0):
+    if not is_amd():
+        return False
+
+    if is_device_cpu(device):
+        return False
+
+    arch = torch.cuda.get_device_properties(device).gcnArchName
+    if arch.startswith('gfx') and len(arch) == 7:
+        try:
+            cmp_rdna_version = int(arch[4]) + 2
+        except:
+            cmp_rdna_version = 0
+        if cmp_rdna_version >= min_rdna_version:
+            return True
+
+    return False
+
 MIN_WEIGHT_MEMORY_RATIO = 0.4
 if is_nvidia():
     MIN_WEIGHT_MEMORY_RATIO = 0.0
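The new amd_min_version() heuristically derives an RDNA generation from a 4-digit gcnArchName: the digit at index 4 plus 2, so gfx10xx maps to 2, gfx11xx to 3, gfx12xx to 4, while CDNA parts such as gfx90a or gfx942 fail the length check and fall through to False. A small standalone sketch of just the string parsing (the function name here is hypothetical, and the torch/device handling is stripped out):

def rdna_version_from_arch(arch: str) -> int:
    # heuristic from the new amd_min_version(): digit at index 4 of a
    # "gfx" + 4-character name, plus 2; 0 if the name doesn't match
    if arch.startswith('gfx') and len(arch) == 7:
        try:
            return int(arch[4]) + 2
        except ValueError:
            return 0
    return 0

for name in ["gfx1030", "gfx1100", "gfx1151", "gfx1201", "gfx90a", "gfx942"]:
    print(name, rdna_version_from_arch(name))
# gfx1030 2, gfx1100 3, gfx1151 3, gfx1201 4, gfx90a 0, gfx942 0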
@@ -321,12 +340,13 @@ try:
         logging.info("AMD arch: {}".format(arch))
         logging.info("ROCm version: {}".format(rocm_version))
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-            if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
-                if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
-                    ENABLE_PYTORCH_ATTENTION = True
-            # if torch_version_numeric >= (2, 8):
-                # if any((a in arch) for a in ["gfx1201"]):
-                    # ENABLE_PYTORCH_ATTENTION = True
+            if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
+                if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
+                    if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
+                        ENABLE_PYTORCH_ATTENTION = True
+                # if torch_version_numeric >= (2, 8):
+                    # if any((a in arch) for a in ["gfx1201"]):
+                        # ENABLE_PYTORCH_ATTENTION = True
         if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
             if any((a in arch) for a in ["gfx1201", "gfx942", "gfx950"]): # TODO: more arches
                 SUPPORT_FP8_OPS = True
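PyTorch attention on AMD is now only auto-enabled when triton is importable, which is why the earlier hunk adds import importlib. The probe uses importlib.util.find_spec so the package is detected without actually being imported at startup. A quick sketch of that check:

import importlib.util

# find_spec returns a ModuleSpec if the package is installed and None otherwise,
# without importing the package itself (unlike a try/except around "import triton")
has_triton = importlib.util.find_spec('triton') is not None
print("triton available:", has_triton)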
@@ -905,7 +925,9 @@ def vae_dtype(device=None, allowed_dtypes=[]):
 
         # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
         # slowness still a problem on pytorch nightly 2.9.0.dev20250720+rocm6.4 tested on RDNA3
-        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+        # also a problem on RDNA4 except fp32 is also slow there.
+        # This is due to large bf16 convolutions being extremely slow.
+        if d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device):
             return d
 
     return torch.float32
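The VAE dtype gate now allows bf16 on AMD only from RDNA4 onward (where, per the comment, fp32 is also slow, so bf16 is no longer a clear loss); non-AMD behavior is unchanged. A toy truth table of the new condition with the device checks stubbed out into plain booleans (pick_vae_dtype and its parameters are hypothetical stand-ins, not the real function):

import torch

def pick_vae_dtype(d, amd, rdna_ge_4, bf16_ok=True):
    # stubbed version of:
    #   d == torch.bfloat16 and ((not is_amd()) or amd_min_version(device, min_rdna_version=4)) and should_use_bf16(device)
    if d == torch.bfloat16 and ((not amd) or rdna_ge_4) and bf16_ok:
        return d
    return torch.float32

print(pick_vae_dtype(torch.bfloat16, amd=False, rdna_ge_4=False))  # torch.bfloat16 (non-AMD, as before)
print(pick_vae_dtype(torch.bfloat16, amd=True,  rdna_ge_4=False))  # torch.float32  (RDNA3 and older stay on fp32)
print(pick_vae_dtype(torch.bfloat16, amd=True,  rdna_ge_4=True))   # torch.bfloat16 (new: RDNA4 may use bf16)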
@@ -140,11 +140,12 @@ def precompute_freqs_cis(head_dim, position_ids, theta, rope_dims=None, device=N
 
 
 def apply_rope(xq, xk, freqs_cis):
+    org_dtype = xq.dtype
     cos = freqs_cis[0]
     sin = freqs_cis[1]
     q_embed = (xq * cos) + (rotate_half(xq) * sin)
     k_embed = (xk * cos) + (rotate_half(xk) * sin)
-    return q_embed, k_embed, sin, cos
+    return q_embed.to(org_dtype), k_embed.to(org_dtype), sin, cos
 
 
 class LlamaRoPE(nn.Module):
     def __init__(self, config, device = None, dtype = None):
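When cos/sin are stored in a higher precision than the activations (for example fp32 rotary tables with fp16 or bf16 queries), the multiplications in apply_rope promote the result to the wider dtype; capturing org_dtype and casting back keeps the returned embeddings in the caller's dtype. A small demonstration of the promotion and the fix, using a conventional rotate_half and simplified shapes:

import torch

def rotate_half(x):
    # standard RoPE helper: split the last dim in half and rotate (x1, x2) -> (-x2, x1)
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

xq = torch.randn(1, 8, dtype=torch.float16)
cos = torch.randn(1, 8, dtype=torch.float32)  # higher-precision rotary table
sin = torch.randn(1, 8, dtype=torch.float32)

org_dtype = xq.dtype
q_embed = (xq * cos) + (rotate_half(xq) * sin)
print(q_embed.dtype)                # torch.float32 -- silently promoted
print(q_embed.to(org_dtype).dtype)  # torch.float16 -- what the change returns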
@@ -162,7 +162,12 @@ def easycache_sample_wrapper(executor, *args, **kwargs):
             logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
             logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
         total_steps = len(args[3])-1
-        logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
+        # catch division by zero for log statement; sucks to crash after all sampling is done
+        try:
+            speedup = total_steps/(total_steps-easycache.total_steps_skipped)
+        except ZeroDivisionError:
+            speedup = 1.0
+        logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
         easycache.reset()
         guider.model_options = orig_model_options
 
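If EasyCache ends up skipping every step, the old speedup expression divides by zero and crashes the final log line after sampling has already finished; the try/except swaps in a safe 1.0 fallback just for the message. A minimal reproduction of the guarded computation (the step counts are made up):

total_steps = 20
total_steps_skipped = 20  # everything skipped -> the old expression divided by zero

try:
    speedup = total_steps / (total_steps - total_steps_skipped)
except ZeroDivisionError:
    speedup = 1.0

print(f"skipped {total_steps_skipped}/{total_steps} steps ({speedup:.2f}x speedup).")
# skipped 20/20 steps (1.00x speedup).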