diff --git a/README.md b/README.md
index eb051d0e3..bc69cce1c 100644
--- a/README.md
+++ b/README.md
@@ -325,6 +325,17 @@ You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI
 
 You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
 
+###### Notes for Ascend NPU Users
+
+These instructions from upstream have not yet been validated.
+
+For models compatible with the Ascend Extension for PyTorch (`torch_npu`), first make sure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page, then follow the steps below for your platform and installation method:
+
+1. If necessary, install the recommended or newer Linux kernel version specified on the torch-npu installation page.
+2. Install the Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions for your platform.
+3. Install the packages required by torch-npu, following the platform-specific instructions on the [installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
+4. Follow the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier.
+
 ## Manual Install (Windows, Linux, macOS) For Development
 
 1. Clone this repo:
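A quick sanity check for the Ascend setup described above is sketched below. It is not part of the upstream instructions; it assumes `torch` and `torch_npu` are already installed per the linked pages, and it only relies on calls that this patch itself uses (`torch.npu.is_available()` and `torch.npu.device_count()`).

```python
# Minimal post-install check for an Ascend NPU environment (sketch, not part of the patch).
import torch
import torch_npu  # noqa: F401  # importing torch_npu registers the "npu" device type with torch

if torch.npu.is_available():
    print("NPU devices visible:", torch.npu.device_count())
else:
    print("torch_npu imported, but no NPU device is available")
```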
help="Enables per-user storage.") parser.add_argument("--create-directories", action="store_true", help="Creates the default models/, input/, output/ and temp/ directories, then exits.") + parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).") parser.add_argument("--plausible-analytics-base-url", required=False, help="Enables server-side analytics events sent to the provided URL.") diff --git a/comfy/cli_args_types.py b/comfy/cli_args_types.py index df0aee39c..baa86ef09 100644 --- a/comfy/cli_args_types.py +++ b/comfy/cli_args_types.py @@ -72,6 +72,7 @@ class Configuration(dict): fp8_e5m2_text_enc (bool): Use FP8 precision for the text encoder (e5m2 variant). fp16_text_enc (bool): Use FP16 precision for the text encoder. fp32_text_enc (bool): Use FP32 precision for the text encoder. + openapi_device_selector (Optional[str]): Sets the oneAPI device(s) this instance will use. directml (Optional[int]): Use DirectML. -1 for auto-selection. disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs. preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto". @@ -118,6 +119,7 @@ class Configuration(dict): preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512. openai_api_key (str): Configures the OpenAI API Key for the OpenAI nodes user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path. + log_stdout (bool): Send normal process output to stdout instead of stderr (default) """ def __init__(self, **kwargs): @@ -198,6 +200,8 @@ class Configuration(dict): self.force_hf_local_dir_mode = False self.preview_size: int = 512 self.logging_level: str = "INFO" + self.openapi_device_selector: Optional[str] = None + self.log_stdout: bool = False # from guill self.cache_lru: int = 0 diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index 6e6f443ac..980049414 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -30,6 +30,8 @@ from ..tracing_compatibility import ProgressSpanSampler from ..tracing_compatibility import patch_spanbuilder_set_channel from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor +this_logger = logging.getLogger(__name__) + options.enable_args_parsing() if os.name == "nt": logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage()) @@ -43,12 +45,17 @@ from ..cli_args import args if args.cuda_device is not None: os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device) os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device) - logging.info("Set cuda device to: {}".format(args.cuda_device)) + this_logger.info("Set cuda device to: {}".format(args.cuda_device)) if args.deterministic: if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ: os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8" +if args.oneapi_device_selector is not None: + os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector + this_logger.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector)) + + try: from . 
import cuda_malloc except Exception: @@ -76,11 +83,11 @@ def _fix_pytorch_240(): try: _ = ctypes.cdll.LoadLibrary(test_file) except FileNotFoundError: - logging.warning("Detected pytorch version with libomp issue, trying to patch") + this_logger.warning("Detected pytorch version with libomp issue, trying to patch") try: shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest) except Exception as exc_info: - logging.error("While trying to patch a fix for torch 2.4.0, an error occurred, which means this is unlikely to work", exc_info=exc_info) + this_logger.error("While trying to patch a fix for torch 2.4.0, an error occurred, which means this is unlikely to work", exc_info=exc_info) except: pass diff --git a/comfy/cmd/server.py b/comfy/cmd/server.py index 0024e536f..0a01c0f39 100644 --- a/comfy/cmd/server.py +++ b/comfy/cmd/server.py @@ -939,9 +939,7 @@ class PromptServer(ExecutorToClientProgress): self.app.add_routes(self.routes) for name, dir in self.nodes.EXTENSION_WEB_DIRS.items(): - self.app.add_routes([ - web.static('/extensions/' + quote(name), dir, follow_symlinks=True), - ]) + self.app.add_routes([web.static('/extensions/' + name, dir, follow_symlinks=True)]) self.app.add_routes([ web.static('/', self.web_root, follow_symlinks=True), diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 6e83959fc..1be6274ff 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -80,7 +80,7 @@ class NoiseScheduleVP: 'linear' or 'cosine' for continuous-time DPMs. Returns: A wrapper object of the forward SDE (VP type). - + =============================================================== Example: @@ -208,7 +208,7 @@ def model_wrapper( arXiv preprint arXiv:2202.00512 (2022). [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models." arXiv preprint arXiv:2210.02303 (2022). - + 4. "score": marginal score function. (Trained by denoising score matching). Note that the score function and the noise prediction model follows a simple relationship: ``` @@ -245,7 +245,7 @@ def model_wrapper( [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance." arXiv preprint arXiv:2207.12598 (2022). - + The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999) or continuous-time labels (i.e. epsilon to T). 
@@ -623,7 +623,7 @@ class UniPC: B_h = torch.expm1(hh) else: raise NotImplementedError() - + for i in range(1, order + 1): R.append(torch.pow(rks, i - 1)) b.append(h_phi_k * factorial_i / B_h) @@ -874,4 +874,4 @@ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=F return x def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False): - return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2') \ No newline at end of file + return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2') diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index c6e4a68c8..1ab747f56 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -1274,7 +1274,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis nonlocal uncond_denoised uncond_denoised = args["uncond_denoised"] return args["denoised"] - + model_options = extra_args.get("model_options", {}).copy() extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 1488ff2a6..4fe52dab8 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -377,7 +377,7 @@ class Decoder(nn.Module): assert ( timestep is not None ), "should pass timestep with timestep_conditioning=True" - scaled_timestep = timestep * self.timestep_scale_multiplier + scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device) for up_block in self.up_blocks: if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): @@ -402,7 +402,7 @@ class Decoder(nn.Module): ) ada_values = self.last_scale_shift_table[ None, ..., None, None, None - ] + embedded_timestep.reshape( + ].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape( batch_size, 2, -1, @@ -696,7 +696,7 @@ class ResnetBlock3D(nn.Module): ), "should pass timestep with timestep_conditioning=True" ada_values = self.scale_shift_table[ None, ..., None, None, None - ] + timestep.reshape( + ].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape( batch_size, 4, -1, @@ -714,7 +714,7 @@ class ResnetBlock3D(nn.Module): if self.inject_noise: hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale1 + hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype) ) hidden_states = self.norm2(hidden_states) @@ -730,7 +730,7 @@ class ResnetBlock3D(nn.Module): if self.inject_noise: hidden_states = self._feed_spatial_noise( - hidden_states, self.per_channel_scale2 + hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype) ) input_tensor = self.norm3(input_tensor) diff --git a/comfy/ldm/modules/sub_quadratic_attention.py b/comfy/ldm/modules/sub_quadratic_attention.py index 8c4c6efbd..86ab29462 100644 --- a/comfy/ldm/modules/sub_quadratic_attention.py +++ b/comfy/ldm/modules/sub_quadratic_attention.py @@ -261,7 +261,7 @@ def efficient_dot_product_attention( value=value, mask=mask, ) - + # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance, # and pass slices to be mutated, instead of torch.cat()ing the returned slices res = torch.cat([ diff --git 
a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py index 4d3d537ec..09c87133c 100644 --- a/comfy/ldm/pixart/blocks.py +++ b/comfy/ldm/pixart/blocks.py @@ -6,16 +6,16 @@ import torch.nn as nn import torch.nn.functional as F from einops import rearrange -from ... import model_management -from ..modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding from ..modules.attention import optimized_attention +from ..modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding -if model_management.xformers_enabled(): - import xformers.ops - if int((xformers.__version__).split(".")[2]) >= 28: - block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens - else: - block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens + +# if model_management.xformers_enabled(): +# import xformers.ops +# if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28: +# block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens +# else: +# block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens def modulate(x, shift, scale): return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) diff --git a/comfy/ldm/pixart/pixartms.py b/comfy/ldm/pixart/pixartms.py index 4febbe102..a4a9e592e 100644 --- a/comfy/ldm/pixart/pixartms.py +++ b/comfy/ldm/pixart/pixartms.py @@ -223,7 +223,7 @@ class PixArtMS(nn.Module): if self.micro_conditioning: if c_size is None: c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1) - + if c_ar is None: c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index b8c91eb3c..cbf987ed3 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -577,7 +577,6 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None): 'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1], 'use_temporal_attention': False, 'use_temporal_resblock': False} - SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None, 'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0], 'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8, diff --git a/comfy/model_management.py b/comfy/model_management.py index c60a627ab..718920cd7 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -98,7 +98,7 @@ try: import intel_extension_for_pytorch as ipex # pylint: disable=import-error _ = torch.xpu.device_count() - xpu_available = torch.xpu.is_available() + xpu_available = xpu_available or torch.xpu.is_available() except: xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available()) @@ -109,6 +109,14 @@ try: except: pass +try: + import torch_npu # noqa: F401 + + _ = torch.npu.device_count() + npu_available = torch.npu.is_available() +except: + npu_available = False + if args.cpu: cpu_state = CPUState.CPU @@ -122,6 +130,13 @@ def is_intel_xpu(): return False +def is_ascend_npu(): + global npu_available + if npu_available: + return True + return False + + def get_torch_device(): global 
directml_device global cpu_state @@ -134,6 +149,8 @@ def get_torch_device(): else: if is_intel_xpu(): return torch.device("xpu", torch.xpu.current_device()) + elif is_ascend_npu(): + return torch.device("npu", torch.npu.current_device()) else: try: return torch.device(f"cuda:{torch.cuda.current_device()}") @@ -147,6 +164,7 @@ def get_total_memory(dev=None, torch_total_too=False): if dev is None: dev = get_torch_device() + mem_total_torch = 0 if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'): mem_total = psutil.virtual_memory().total mem_total_torch = mem_total @@ -156,7 +174,12 @@ def get_total_memory(dev=None, torch_total_too=False): mem_total_torch = mem_total elif is_intel_xpu(): mem_total = torch.xpu.get_device_properties(dev).total_memory - mem_total_torch = mem_total + elif is_ascend_npu(): + stats = torch.npu.memory_stats(dev) + mem_reserved = stats['reserved_bytes.all.current'] + _, mem_total_npu = torch.npu.mem_get_info(dev) + mem_total_torch = mem_reserved + mem_total = mem_total_npu else: stats = torch.cuda.memory_stats(dev) mem_reserved = stats['reserved_bytes.all.current'] @@ -234,6 +257,14 @@ def is_amd(): return False +def is_amd(): + global cpu_state + if cpu_state == CPUState.GPU: + if torch.version.hip: + return True + return False + + MIN_WEIGHT_MEMORY_RATIO = 0.4 if is_nvidia(): MIN_WEIGHT_MEMORY_RATIO = 0.2 @@ -243,32 +274,28 @@ if args.use_pytorch_cross_attention: ENABLE_PYTORCH_ATTENTION = True XFORMERS_IS_AVAILABLE = False -VAE_DTYPES = [torch.float32] - try: if is_nvidia() or is_amd(): if int(torch_version[0]) >= 2: if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True - if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8: - VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES - if is_intel_xpu(): + if is_intel_xpu() or is_ascend_npu(): if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True except: pass -if is_intel_xpu(): - VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES - -if args.cpu_vae: - VAE_DTYPES = [torch.float32] - if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) +try: + if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5: + torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True) +except: + logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp") + if args.lowvram: set_vram_to = VRAMState.LOW_VRAM lowvram_available = True @@ -322,6 +349,8 @@ def get_torch_device_name(device): return "{}".format(device.type) elif is_intel_xpu(): return "{} {}".format(device, torch.xpu.get_device_name(device)) + elif is_ascend_npu(): + return "{} {}".format(device, torch.npu.get_device_name(device)) else: return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device)) @@ -605,7 +634,7 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0 lowvram_model_memory = 0 if vram_set_state == VRAMState.NO_VRAM: - lowvram_model_memory = 64 * 1024 * 1024 + lowvram_model_memory = 0.1 loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights) current_loaded_models.insert(0, loaded_model) @@ -852,7 +881,6 @@ def vae_offload_device(): def vae_dtype(device=None, allowed_dtypes=[]): - global VAE_DTYPES if args.fp16_vae: return torch.float16 elif 
args.bf16_vae: @@ -861,12 +889,14 @@ def vae_dtype(device=None, allowed_dtypes=[]): return torch.float32 for d in allowed_dtypes: - if d == torch.float16 and should_use_fp16(device, prioritize_performance=False): - return d - if d in VAE_DTYPES: + if d == torch.float16 and should_use_fp16(device): return d - return VAE_DTYPES[0] + # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32 + if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device): + return d + + return torch.float32 def get_autocast_device(dev): @@ -984,6 +1014,8 @@ def xformers_enabled(): return False if is_intel_xpu(): return False + if is_ascend_npu(): + return False if directml_device: return False return XFORMERS_IS_AVAILABLE @@ -1022,17 +1054,25 @@ def pytorch_attention_flash_attention(): return True if is_intel_xpu(): return True + if is_ascend_npu(): + return True return False +def mac_version() -> Optional[tuple[int, ...]]: + try: + return tuple(int(n) for n in platform.mac_ver()[0].split(".")) + except: + return None + + def force_upcast_attention_dtype(): upcast = args.force_upcast_attention - try: - macos_version = tuple(int(n) for n in platform.mac_ver()[0].split(".")) - if (14, 5) <= macos_version <= (15, 2): # black image bug on recent versions of macOS - upcast = True - except: - pass + + macos_version = mac_version() + if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)): # black image bug on recent versions of macOS + upcast = True + if upcast: return torch.float32 else: @@ -1052,8 +1092,19 @@ def get_free_memory(dev=None, torch_free_too=False): mem_free_total = 1024 * 1024 * 1024 # TODO mem_free_torch = mem_free_total elif is_intel_xpu(): - mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_free_torch = mem_free_total + stats = torch.xpu.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_torch = mem_reserved - mem_active + mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved + mem_free_total = mem_free_xpu + mem_free_torch + elif is_ascend_npu(): + stats = torch.npu.memory_stats(dev) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_npu, _ = torch.npu.mem_get_info(dev) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_npu + mem_free_torch else: stats = torch.cuda.memory_stats(dev) mem_active = stats['active_bytes.all.current'] @@ -1107,17 +1158,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if FORCE_FP16: return True - if device is not None: - if is_device_mps(device): - return True - if FORCE_FP32: return False if directml_device: return False - if mps_mode(): + if (device is not None and is_device_mps(device)) or mps_mode(): return True if cpu_mode(): @@ -1126,6 +1173,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma if is_intel_xpu(): return True + if is_ascend_npu(): + return True + if is_amd(): return True try: @@ -1176,17 +1226,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma if is_device_cpu(device): # TODO ? 
bf16 works on CPU but is extremely slow return False - if device is not None: - if is_device_mps(device): - return True - if FORCE_FP32: return False if directml_device: return False - if mps_mode(): + if (device is not None and is_device_mps(device)) or mps_mode(): + if mac_version() < (14,): + return False return True if cpu_mode(): @@ -1244,15 +1292,21 @@ def supports_fp8_compute(device=None): def soft_empty_cache(force=False): with model_management_lock: - global cpu_state - if cpu_state == CPUState.MPS: - torch.mps.empty_cache() - elif is_intel_xpu(): - torch.xpu.empty_cache() # pylint: disable=no-member - elif torch.cuda.is_available(): - if force or is_nvidia(): # This seems to make things worse on ROCm so I only do it for cuda - torch.cuda.empty_cache() - torch.cuda.ipc_collect() + _soft_empty_cache(force=force) + + +def _soft_empty_cache(force=False): + global cpu_state + if cpu_state == CPUState.MPS: + torch.mps.empty_cache() # pylint: disable=no-member + elif is_intel_xpu(): + torch.xpu.empty_cache() # pylint: disable=no-member + elif is_ascend_npu(): + torch.npu.empty_cache() # pylint: disable=no-member + elif torch.cuda.is_available(): + if force or is_nvidia(): # This seems to make things worse on ROCm so I only do it for cuda + torch.cuda.empty_cache() + torch.cuda.ipc_collect() def unload_all_models(): diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index d15f7242b..693c8e3ed 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -292,17 +292,29 @@ class VAEDecodeTiled: return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), + "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), }} RETURN_TYPES = ("IMAGE",) FUNCTION = "decode" CATEGORY = "_for_testing" - def decode(self, vae, samples, tile_size, overlap=64): + def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8): if tile_size < overlap * 4: overlap = tile_size // 4 + if temporal_size < temporal_overlap * 2: + temporal_overlap = temporal_overlap // 2 + temporal_compression = vae.temporal_compression_decode() + if temporal_compression is not None: + temporal_size = max(2, temporal_size // temporal_compression) + temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression) + else: + temporal_size = None + temporal_overlap = None + compression = vae.spacial_compression_decode() - images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression) + images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap) if len(images.shape) == 5: #Combine batches images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1]) return (images, ) @@ -326,15 +338,17 @@ class VAEEncodeTiled: return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ), "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}), "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}), + "temporal_size": 
("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}), + "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}), }} RETURN_TYPES = ("LATENT",) FUNCTION = "encode" CATEGORY = "_for_testing" - def encode(self, vae, pixels, tile_size, overlap): - t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap) - return ({"samples":t}, ) + def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8): + t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap) + return ({"samples": t}, ) class VAEEncodeForInpaint: @classmethod @@ -1664,7 +1678,6 @@ class LoadImage: def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]: image_path = folder_paths.get_annotated_filepath(image) - output_images = [] output_masks = [] w, h = None, None diff --git a/comfy/ops.py b/comfy/ops.py index fbeaef17b..8f6de810d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -308,8 +308,10 @@ def fp8_linear(self, input): tensor_2d = True input = input.unsqueeze(1) + input_shape = input.shape + input_dtype = input.dtype if len(input.shape) == 3: - w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype) + w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype) w = w.t() scale_weight = self.scale_weight @@ -321,23 +323,24 @@ def fp8_linear(self, input): if scale_input is None: scale_input = torch.ones((), device=input.device, dtype=torch.float32) - inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype) + input = torch.clamp(input, min=-448, max=448, out=input) + input = input.reshape(-1, input_shape[2]).to(dtype) else: scale_input = scale_input.to(input.device) - inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype) + input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype) if bias is not None: - o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) + o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight) else: - o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight) + o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight) if isinstance(o, tuple): o = o[0] if tensor_2d: - return o.reshape(input.shape[0], -1) + return o.reshape(input_shape[0], -1) - return o.reshape((-1, input.shape[1], self.weight.shape[0])) + return o.reshape((-1, input_shape[1], self.weight.shape[0])) return None diff --git a/comfy/sample.py b/comfy/sample.py index 961ba72c4..0f63e673f 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -15,7 +15,7 @@ def prepare_noise(latent_image, seed, noise_inds=None): generator = torch.manual_seed(seed) if noise_inds is None: return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") - + unique_inds, inverse = np.unique(noise_inds, return_inverse=True) noises = [] for i in range(unique_inds[-1]+1): diff --git a/comfy/sd.py b/comfy/sd.py index 91845b071..1b13e06d5 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -121,7 +121,7 @@ class CLIP: model_management.load_models_gpu([self.patcher], force_full_load=True) 
self.layer_idx = None self.use_clip_schedule = False - logger.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device'])) + logger.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype)) def clone(self): n = CLIP(no_init=True) @@ -271,6 +271,9 @@ class VAE: self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.working_dtypes = [torch.bfloat16, torch.float32] + self.downscale_index_formula = None + self.upscale_index_formula = None + if config is None: if "decoder.mid.block_1.mix_factor" in sd: encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} @@ -350,7 +353,9 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8) + self.upscale_index_formula = (6, 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8) + self.downscale_index_formula = (6, 8, 8) self.working_dtypes = [torch.float16, torch.float32] elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd: # lightricks ltxv tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"] @@ -365,14 +370,18 @@ class VAE: self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype) self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype) self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32) + self.upscale_index_formula = (8, 32, 32) self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32) + self.downscale_index_formula = (8, 32, 32) self.working_dtypes = [torch.bfloat16, torch.float32] elif "decoder.conv_in.conv.weight" in sd: ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0} ddconfig["conv3d"] = True ddconfig["time_compress"] = 4 self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8) + self.upscale_index_formula = (4, 8, 8) self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8) + self.downscale_index_formula = (4, 8, 8) self.latent_dim = 3 self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1] self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1]) @@ -405,7 +414,7 @@ class VAE: self.output_device = model_management.intermediate_device() self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device) - logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) + logger.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype)) def vae_encode_crop_pixels(self, pixels): downscale_ratio = self.spacial_compression_encode() @@ -438,7 +447,7 @@ class VAE: def decode_tiled_3d(self, samples, tile_t=999, 
tile_x=32, tile_y=32, overlap=(1, 8, 8)): decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() - return self.process_output(utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) + return self.process_output(utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64): steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap) @@ -459,7 +468,7 @@ class VAE: def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() - return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device) + return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) def decode(self, samples_in): pixel_samples = None @@ -491,7 +500,7 @@ class VAE: pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1) return pixel_samples - def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None): + def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): memory_used = self.memory_used_decode(samples.shape, self.vae_dtype) # TODO: calculate mem required for tile load_models_gpu([self.patcher], memory_required=memory_used) dims = samples.ndim - 2 @@ -509,6 +518,13 @@ class VAE: elif dims == 2: output = self.decode_tiled_(samples, **args) elif dims == 3: + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (max(1, overlap_t), overlap, overlap) + if tile_t is not None: + args["tile_t"] = max(2, tile_t) + output = self.decode_tiled_3d(samples, **args) else: raise ValueError(f"invalid dims={dims}") @@ -546,7 +562,7 @@ class VAE: return samples - def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None): + def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None): pixel_samples = self.vae_encode_crop_pixels(pixel_samples) dims = self.latent_dim pixel_samples = pixel_samples.movedim(-1, 1) @@ -571,7 +587,20 @@ class VAE: elif dims == 2: samples = self.encode_tiled_(pixel_samples, **args) elif dims == 3: - samples = self.encode_tiled_3d(pixel_samples, **args) + if tile_t is not None: + tile_t_latent = max(2, self.downscale_ratio[0](tile_t)) + else: + tile_t_latent = 9999 + args["tile_t"] = self.upscale_ratio[0](tile_t_latent) + + if overlap_t is None: + args["overlap"] = (1, overlap, overlap) + else: + args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap) + maximum = pixel_samples.shape[2] + maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum)) + + 
samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args) else: raise ValueError(f"unsupported values dim {dims}") @@ -592,6 +621,11 @@ class VAE: except: return self.downscale_ratio + def temporal_compression_decode(self): + try: + return round(self.upscale_ratio[0](8192) / 8192) + except: + return None class StyleModel: def __init__(self, model, device="cpu"): diff --git a/comfy/supported_models.py b/comfy/supported_models.py index eeed03004..70aab3506 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -627,6 +627,8 @@ class PixArtAlpha(supported_models_base.BASE): unet_extra_config = {} latent_format = latent_formats.SD15 + memory_usage_factor = 0.5 + vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] @@ -663,6 +665,8 @@ class HunyuanDiT(supported_models_base.BASE): latent_format = latent_formats.SDXL + memory_usage_factor = 1.3 + vae_key_prefix = ["vae."] text_encoder_key_prefix = ["text_encoders."] diff --git a/comfy/utils.py b/comfy/utils.py index f042a5e3c..bf00efda1 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -912,7 +912,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap): @torch.inference_mode() -def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None): +def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None): dims = len(tile) if not (isinstance(upscale_amount, (tuple, list))): @@ -921,6 +921,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am if not (isinstance(overlap, (tuple, list))): overlap = [overlap] * dims + if index_formulas is None: + index_formulas = upscale_amount + + if not (isinstance(index_formulas, (tuple, list))): + index_formulas = [index_formulas] * dims + def get_upscale(dim, val): up = upscale_amount[dim] if callable(up): @@ -935,10 +941,26 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am else: return val / up + def get_upscale_pos(dim, val): + up = index_formulas[dim] + if callable(up): + return up(val) + else: + return up * val + + def get_downscale_pos(dim, val): + up = index_formulas[dim] + if callable(up): + return up(val) + else: + return val / up + if downscale: get_scale = get_downscale + get_pos = get_downscale_pos else: get_scale = get_upscale + get_pos = get_upscale_pos def mult_list_upscale(a): out = [] @@ -970,7 +992,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am pos = max(0, min(s.shape[d + 2] - overlap[d], it[d])) l = min(tile[d], s.shape[d + 2] - pos) s_in = s_in.narrow(d + 2, pos, l) - upscaled.append(round(get_scale(d, pos))) + upscaled.append(round(get_pos(d, pos))) ps = function(s_in).to(output_device) mask = torch.ones_like(ps) diff --git a/comfy_extras/nodes/nodes_mask.py b/comfy_extras/nodes/nodes_mask.py index 4b532c707..12185989b 100644 --- a/comfy_extras/nodes/nodes_mask.py +++ b/comfy_extras/nodes/nodes_mask.py @@ -306,7 +306,7 @@ class FeatherMask: output[:, -y, :] *= feather_rate return (output,) - + class GrowMask: @classmethod def INPUT_TYPES(cls): @@ -317,7 +317,7 @@ class GrowMask: "tapered_corners": ("BOOLEAN", {"default": True}), }, } - + CATEGORY = "mask" RETURN_TYPES = ("MASK",) diff --git a/comfy_extras/nodes/nodes_rebatch.py b/comfy_extras/nodes/nodes_rebatch.py index 3010fbd4b..e29cb9ed1 100644 --- 
a/comfy_extras/nodes/nodes_rebatch.py +++ b/comfy_extras/nodes/nodes_rebatch.py @@ -40,7 +40,7 @@ class LatentRebatch: return slices, indexable[num * batch_size:] else: return slices, None - + @staticmethod def slice_batch(batch, num, batch_size): result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch] @@ -81,7 +81,7 @@ class LatentRebatch: if current_batch[0].shape[0] > batch_size: num = current_batch[0].shape[0] // batch_size sliced, remainder = self.slice_batch(current_batch, num, batch_size) - + for i in range(num): output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]}) diff --git a/comfy_extras/nodes/nodes_tomesd.py b/comfy_extras/nodes/nodes_tomesd.py index ce7b32c77..9f77c06fc 100644 --- a/comfy_extras/nodes/nodes_tomesd.py +++ b/comfy_extras/nodes/nodes_tomesd.py @@ -40,9 +40,8 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, return do_nothing, do_nothing gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather - + with torch.no_grad(): - hsy, wsx = h // sy, w // sx # For each sy by sx kernel, randomly assign one token to be dst and the rest src @@ -50,7 +49,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64) else: rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device) - + # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64) idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype)) @@ -99,7 +98,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor, def merge(x: torch.Tensor, mode="mean") -> torch.Tensor: src, dst = split(x) n, t1, c = src.shape - + unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c)) src = gather(src, dim=-2, index=src_idx.expand(n, r, c)) dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)
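A closing usage note on the `comfy/app/logger.py` change near the top of this diff: with `use_stdout=True`, records below ERROR are routed to stdout while ERROR and CRITICAL stay on stderr. The snippet below is a standalone sketch of that handler arrangement for illustration only; the logger name `comfy_demo` is hypothetical and the code is not part of this patch.

```python
# Sketch of the stdout/stderr split introduced by setup_logger(..., use_stdout=True).
# Hypothetical standalone example; only the filter logic mirrors the patch.
import logging
import sys

logger = logging.getLogger("comfy_demo")
logger.setLevel(logging.INFO)

stderr_handler = logging.StreamHandler()  # defaults to sys.stderr
stderr_handler.addFilter(lambda record: record.levelno >= logging.ERROR)

stdout_handler = logging.StreamHandler(sys.stdout)
stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)

logger.addHandler(stdout_handler)
logger.addHandler(stderr_handler)

logger.info("progress message")    # printed to stdout
logger.error("failure message")    # printed to stderr
```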