mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-15 00:30:55 +08:00

Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in: commit 9d5a5dd533

README.md (11 changes)
@@ -325,6 +325,17 @@ You can enable experimental memory efficient attention on pytorch 2.5 in ComfyUI
 You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
 
+###### Notes for Ascend NPU Users
+
+These instructions from upstream have not yet been validated.
+
+For models compatible with Ascend Extension for PyTorch (`torch_npu`): to get started, ensure your environment meets the prerequisites outlined on the [installation](https://ascend.github.io/docs/sources/ascend/quick_install.html) page. Here's a step-by-step guide tailored to your platform and installation method:
+
+1. Begin by installing the recommended or newer kernel version for Linux as specified in the Installation page of torch-npu, if necessary.
+2. Proceed with the installation of Ascend Basekit, which includes the driver, firmware, and CANN, following the instructions provided for your specific platform.
+3. Next, install the necessary packages for torch-npu by adhering to the platform-specific instructions on the [Installation](https://ascend.github.io/docs/sources/pytorch/install.html#pytorch) page.
+4. Finally, follow the [ComfyUI manual installation](#manual-install-windows-linux) guide for Linux. Once all components are installed, you can run ComfyUI as described earlier; a quick import check is sketched below.
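As a quick sanity check that the Ascend stack is visible to PyTorch before launching ComfyUI, here is a minimal sketch (assuming steps 1-3 above completed; the probes mirror the `torch.npu` calls this commit adds to `model_management` further down):

```python
import torch
import torch_npu  # noqa: F401  # registers the "npu" device type with torch

# Same probes ComfyUI's model_management uses to decide npu_available.
print("NPU device count:", torch.npu.device_count())
print("NPU available:", torch.npu.is_available())
print("Device name:", torch.npu.get_device_name(torch.npu.current_device()))
```

If the import or the probes fail, revisit steps 1-3 before continuing with the manual install below.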
 ## Manual Install (Windows, Linux, macOS) For Development
 
 1. Clone this repo:
@@ -52,7 +52,7 @@ def on_flush(callback):
     stderr_interceptor.on_flush(callback)
 
 
-def setup_logger(log_level: str = 'INFO', capacity: int = 300):
+def setup_logger(log_level: str = 'INFO', capacity: int = 300, use_stdout: bool = False):
     global logs
     if logs:
         return
@@ -71,4 +71,15 @@ def setup_logger(log_level: str = 'INFO', capacity: int = 300):
 
     stream_handler = logging.StreamHandler()
     stream_handler.setFormatter(logging.Formatter("%(message)s"))
+
+    if use_stdout:
+        # Only errors and critical to stderr
+        stream_handler.addFilter(lambda record: not record.levelno < logging.ERROR)
+
+        # Lesser to stdout
+        stdout_handler = logging.StreamHandler(sys.stdout)
+        stdout_handler.setFormatter(logging.Formatter("%(message)s"))
+        stdout_handler.addFilter(lambda record: record.levelno < logging.ERROR)
+        logger.addHandler(stdout_handler)
+
     logger.addHandler(stream_handler)
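For context, a minimal standalone sketch of the routing this change produces (illustrative logger name; the real wiring lives in `setup_logger` above): records at ERROR and above go to stderr, everything below goes to stdout, so shells and supervisors can separate the two streams.

```python
import logging
import sys

logger = logging.getLogger("comfyui-example")  # illustrative name
logger.setLevel(logging.INFO)

err_handler = logging.StreamHandler()  # StreamHandler defaults to stderr
err_handler.addFilter(lambda record: record.levelno >= logging.ERROR)

out_handler = logging.StreamHandler(sys.stdout)
out_handler.addFilter(lambda record: record.levelno < logging.ERROR)

logger.addHandler(err_handler)
logger.addHandler(out_handler)

logger.info("routed to stdout")    # below ERROR
logger.error("routed to stderr")   # ERROR and above
```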
@@ -84,8 +84,9 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1,
                         help="Use torch-directml.")
 
+    parser.add_argument("--oneapi-device-selector", type=str, default=None, metavar="SELECTOR_STRING", help="Sets the oneAPI device(s) this instance will use.")
     parser.add_argument("--disable-ipex-optimize", action="store_true",
-                        help="Disables ipex.optimize when loading models with Intel GPUs.")
+                        help="Disables ipex.optimize default when loading models with Intel's Extension for Pytorch.")
 
     parser.add_argument("--preview-method", type=LatentPreviewMethod, default=LatentPreviewMethod.Auto,
                         help="Default preview method for sampler nodes.", action=EnumAction)
@@ -139,6 +140,7 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument("--multi-user", action="store_true", help="Enables per-user storage.")
     parser.add_argument("--create-directories", action="store_true",
                         help="Creates the default models/, input/, output/ and temp/ directories, then exits.")
+    parser.add_argument("--log-stdout", action="store_true", help="Send normal process output to stdout instead of stderr (default).")
 
     parser.add_argument("--plausible-analytics-base-url", required=False,
                         help="Enables server-side analytics events sent to the provided URL.")
@@ -72,6 +72,7 @@ class Configuration(dict):
         fp8_e5m2_text_enc (bool): Use FP8 precision for the text encoder (e5m2 variant).
         fp16_text_enc (bool): Use FP16 precision for the text encoder.
         fp32_text_enc (bool): Use FP32 precision for the text encoder.
+        openapi_device_selector (Optional[str]): Sets the oneAPI device(s) this instance will use.
         directml (Optional[int]): Use DirectML. -1 for auto-selection.
         disable_ipex_optimize (bool): Disable IPEX optimization for Intel GPUs.
         preview_method (LatentPreviewMethod): Method for generating previews. Defaults to "auto".
@@ -118,6 +119,7 @@ class Configuration(dict):
         preview_size (int): Sets the maximum preview size for sampler nodes. Defaults to 512.
         openai_api_key (str): Configures the OpenAI API Key for the OpenAI nodes
         user_directory (Optional[str]): Set the ComfyUI user directory with an absolute path.
+        log_stdout (bool): Send normal process output to stdout instead of stderr (default)
     """
 
     def __init__(self, **kwargs):
@@ -198,6 +200,8 @@ class Configuration(dict):
         self.force_hf_local_dir_mode = False
         self.preview_size: int = 512
         self.logging_level: str = "INFO"
+        self.openapi_device_selector: Optional[str] = None
+        self.log_stdout: bool = False
 
         # from guill
         self.cache_lru: int = 0
@@ -30,6 +30,8 @@ from ..tracing_compatibility import ProgressSpanSampler
 from ..tracing_compatibility import patch_spanbuilder_set_channel
 from ..vendor.aiohttp_server_instrumentation import AioHttpServerInstrumentor
 
+this_logger = logging.getLogger(__name__)
+
 options.enable_args_parsing()
 if os.name == "nt":
     logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
@@ -43,12 +45,17 @@ from ..cli_args import args
 if args.cuda_device is not None:
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.cuda_device)
-    logging.info("Set cuda device to: {}".format(args.cuda_device))
+    os.environ['HIP_VISIBLE_DEVICES'] = str(args.cuda_device)
+    this_logger.info("Set cuda device to: {}".format(args.cuda_device))
 
 if args.deterministic:
     if 'CUBLAS_WORKSPACE_CONFIG' not in os.environ:
         os.environ['CUBLAS_WORKSPACE_CONFIG'] = ":4096:8"
 
+if args.oneapi_device_selector is not None:
+    os.environ['ONEAPI_DEVICE_SELECTOR'] = args.oneapi_device_selector
+    this_logger.info("Set oneapi device selector to: {}".format(args.oneapi_device_selector))
+
 
 try:
     from . import cuda_malloc
 except Exception:
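For reference, `ONEAPI_DEVICE_SELECTOR` takes oneAPI selector strings; the exact values depend on your driver stack, so treat these as illustrative examples rather than ComfyUI-specific settings:

```python
import os

# Must be set before any SYCL/IPEX initialization touches the devices,
# which is why the block above runs so early in startup.
os.environ["ONEAPI_DEVICE_SELECTOR"] = "level_zero:0"  # e.g. first Level Zero GPU
# Other illustrative forms: "level_zero:0,1" (two devices), "opencl:*" (all OpenCL devices)
```

This is exactly what `--oneapi-device-selector` feeds through in the block above.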
@@ -76,11 +83,11 @@ def _fix_pytorch_240():
         try:
             _ = ctypes.cdll.LoadLibrary(test_file)
         except FileNotFoundError:
-            logging.warning("Detected pytorch version with libomp issue, trying to patch")
+            this_logger.warning("Detected pytorch version with libomp issue, trying to patch")
             try:
                 shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)
             except Exception as exc_info:
-                logging.error("While trying to patch a fix for torch 2.4.0, an error occurred, which means this is unlikely to work", exc_info=exc_info)
+                this_logger.error("While trying to patch a fix for torch 2.4.0, an error occurred, which means this is unlikely to work", exc_info=exc_info)
     except:
         pass
 
 
@@ -939,9 +939,7 @@ class PromptServer(ExecutorToClientProgress):
         self.app.add_routes(self.routes)
 
         for name, dir in self.nodes.EXTENSION_WEB_DIRS.items():
-            self.app.add_routes([
-                web.static('/extensions/' + quote(name), dir, follow_symlinks=True),
-            ])
+            self.app.add_routes([web.static('/extensions/' + name, dir, follow_symlinks=True)])
 
         self.app.add_routes([
             web.static('/', self.web_root, follow_symlinks=True),
@@ -80,7 +80,7 @@ class NoiseScheduleVP:
             'linear' or 'cosine' for continuous-time DPMs.
     Returns:
         A wrapper object of the forward SDE (VP type).
 
 
     ===============================================================
 
     Example:
@@ -208,7 +208,7 @@ def model_wrapper(
             arXiv preprint arXiv:2202.00512 (2022).
         [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
             arXiv preprint arXiv:2210.02303 (2022).
 
 
     4. "score": marginal score function. (Trained by denoising score matching).
         Note that the score function and the noise prediction model follows a simple relationship:
         ```
@@ -245,7 +245,7 @@ def model_wrapper(
 
         [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
             arXiv preprint arXiv:2207.12598 (2022).
 
 
     The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
     or continuous-time labels (i.e. epsilon to T).
@@ -623,7 +623,7 @@ class UniPC:
             B_h = torch.expm1(hh)
         else:
             raise NotImplementedError()
 
 
         for i in range(1, order + 1):
             R.append(torch.pow(rks, i - 1))
             b.append(h_phi_k * factorial_i / B_h)
@@ -874,4 +874,4 @@ def sample_unipc(model, noise, sigmas, extra_args=None, callback=None, disable=F
     return x
 
 def sample_unipc_bh2(model, noise, sigmas, extra_args=None, callback=None, disable=False):
-    return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
+    return sample_unipc(model, noise, sigmas, extra_args, callback, disable, variant='bh2')
@@ -1274,7 +1274,7 @@ def sample_dpmpp_2m_cfg_pp(model, x, sigmas, extra_args=None, callback=None, dis
         nonlocal uncond_denoised
         uncond_denoised = args["uncond_denoised"]
         return args["denoised"]
 
 
     model_options = extra_args.get("model_options", {}).copy()
     extra_args["model_options"] = model_patcher.set_model_options_post_cfg_function(model_options, post_cfg_function, disable_cfg1_optimization=True)
@@ -377,7 +377,7 @@ class Decoder(nn.Module):
             assert (
                 timestep is not None
             ), "should pass timestep with timestep_conditioning=True"
-            scaled_timestep = timestep * self.timestep_scale_multiplier
+            scaled_timestep = timestep * self.timestep_scale_multiplier.to(dtype=sample.dtype, device=sample.device)
 
         for up_block in self.up_blocks:
             if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
@@ -402,7 +402,7 @@ class Decoder(nn.Module):
             )
             ada_values = self.last_scale_shift_table[
                 None, ..., None, None, None
-            ] + embedded_timestep.reshape(
+            ].to(device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape(
                 batch_size,
                 2,
                 -1,
@@ -696,7 +696,7 @@ class ResnetBlock3D(nn.Module):
             ), "should pass timestep with timestep_conditioning=True"
             ada_values = self.scale_shift_table[
                 None, ..., None, None, None
-            ] + timestep.reshape(
+            ].to(device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape(
                 batch_size,
                 4,
                 -1,
@@ -714,7 +714,7 @@ class ResnetBlock3D(nn.Module):
 
         if self.inject_noise:
             hidden_states = self._feed_spatial_noise(
-                hidden_states, self.per_channel_scale1
+                hidden_states, self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype)
             )
 
         hidden_states = self.norm2(hidden_states)
@@ -730,7 +730,7 @@ class ResnetBlock3D(nn.Module):
 
         if self.inject_noise:
             hidden_states = self._feed_spatial_noise(
-                hidden_states, self.per_channel_scale2
+                hidden_states, self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype)
             )
 
         input_tensor = self.norm3(input_tensor)
@@ -261,7 +261,7 @@ def efficient_dot_product_attention(
             value=value,
             mask=mask,
         )
 
     # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
     # and pass slices to be mutated, instead of torch.cat()ing the returned slices
     res = torch.cat([
@@ -6,16 +6,16 @@ import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
 
 from ... import model_management
-from ..modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
 from ..modules.attention import optimized_attention
+from ..modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding
 
-if model_management.xformers_enabled():
-    import xformers.ops
-    if int((xformers.__version__).split(".")[2]) >= 28:
-        block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
-    else:
-        block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
+# if model_management.xformers_enabled():
+#     import xformers.ops
+#     if int((xformers.__version__).split(".")[2].split("+")[0]) >= 28:
+#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens
+#     else:
+#         block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens
 
 def modulate(x, shift, scale):
     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
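The extra `.split("+")[0]` in the commented-out block makes the minor-version parse tolerate local build suffixes. A hedged illustration with a made-up version string:

```python
version = "0.0.28+5d4b1fc"    # illustrative dev-build style version
part = version.split(".")[2]  # "28+5d4b1fc"
# int(part) would raise ValueError; stripping the local suffix first works:
print(int(part.split("+")[0]))  # 28
```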
@@ -223,7 +223,7 @@ class PixArtMS(nn.Module):
         if self.micro_conditioning:
             if c_size is None:
                 c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1)
 
 
             if c_ar is None:
                 c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1)
 
 
@@ -577,7 +577,6 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
                     'transformer_depth': [0, 1, 1], 'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': False,
                     'context_dim': 768, 'num_head_channels': 64, 'transformer_depth_output': [0, 0, 1, 1, 1, 1],
                     'use_temporal_attention': False, 'use_temporal_resblock': False}
 
     SD15_diffusers_inpaint = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False, 'adm_in_channels': None,
                               'dtype': dtype, 'in_channels': 9, 'model_channels': 320, 'num_res_blocks': [2, 2, 2, 2], 'transformer_depth': [1, 1, 1, 1, 1, 1, 0, 0],
                               'channel_mult': [1, 2, 4, 4], 'transformer_depth_middle': 1, 'use_linear_in_transformer': False, 'context_dim': 768, 'num_heads': 8,
@@ -98,7 +98,7 @@ try:
     import intel_extension_for_pytorch as ipex  # pylint: disable=import-error
 
     _ = torch.xpu.device_count()
-    xpu_available = torch.xpu.is_available()
+    xpu_available = xpu_available or torch.xpu.is_available()
 except:
     xpu_available = xpu_available or (hasattr(torch, "xpu") and torch.xpu.is_available())
@@ -109,6 +109,14 @@ try:
 except:
     pass
 
+try:
+    import torch_npu  # noqa: F401
+
+    _ = torch.npu.device_count()
+    npu_available = torch.npu.is_available()
+except:
+    npu_available = False
+
 if args.cpu:
     cpu_state = CPUState.CPU
@@ -122,6 +130,13 @@ def is_intel_xpu():
     return False
 
 
+def is_ascend_npu():
+    global npu_available
+    if npu_available:
+        return True
+    return False
+
+
 def get_torch_device():
     global directml_device
     global cpu_state
@@ -134,6 +149,8 @@ def get_torch_device():
     else:
         if is_intel_xpu():
             return torch.device("xpu", torch.xpu.current_device())
+        elif is_ascend_npu():
+            return torch.device("npu", torch.npu.current_device())
         else:
             try:
                 return torch.device(f"cuda:{torch.cuda.current_device()}")
@@ -147,6 +164,7 @@ def get_total_memory(dev=None, torch_total_too=False):
     if dev is None:
         dev = get_torch_device()
 
+    mem_total_torch = 0
     if hasattr(dev, 'type') and (dev.type == 'cpu' or dev.type == 'mps'):
         mem_total = psutil.virtual_memory().total
         mem_total_torch = mem_total
@@ -156,7 +174,12 @@ def get_total_memory(dev=None, torch_total_too=False):
         mem_total_torch = mem_total
     elif is_intel_xpu():
         mem_total = torch.xpu.get_device_properties(dev).total_memory
         mem_total_torch = mem_total
+    elif is_ascend_npu():
+        stats = torch.npu.memory_stats(dev)
+        mem_reserved = stats['reserved_bytes.all.current']
+        _, mem_total_npu = torch.npu.mem_get_info(dev)
+        mem_total_torch = mem_reserved
+        mem_total = mem_total_npu
     else:
         stats = torch.cuda.memory_stats(dev)
         mem_reserved = stats['reserved_bytes.all.current']
@@ -234,6 +257,14 @@ def is_amd():
     return False
 
 
+def is_amd():
+    global cpu_state
+    if cpu_state == CPUState.GPU:
+        if torch.version.hip:
+            return True
+    return False
+
+
 MIN_WEIGHT_MEMORY_RATIO = 0.4
 if is_nvidia():
     MIN_WEIGHT_MEMORY_RATIO = 0.2
@@ -243,32 +274,28 @@ if args.use_pytorch_cross_attention:
     ENABLE_PYTORCH_ATTENTION = True
     XFORMERS_IS_AVAILABLE = False
 
-VAE_DTYPES = [torch.float32]
-
 try:
     if is_nvidia() or is_amd():
         if int(torch_version[0]) >= 2:
             if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
                 ENABLE_PYTORCH_ATTENTION = True
-            if torch.cuda.is_bf16_supported() and torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8:
-                VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
-    if is_intel_xpu():
+    if is_intel_xpu() or is_ascend_npu():
         if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
             ENABLE_PYTORCH_ATTENTION = True
 except:
     pass
 
-if is_intel_xpu():
-    VAE_DTYPES = [torch.bfloat16] + VAE_DTYPES
-
-if args.cpu_vae:
-    VAE_DTYPES = [torch.float32]
-
 if ENABLE_PYTORCH_ATTENTION:
     torch.backends.cuda.enable_math_sdp(True)
     torch.backends.cuda.enable_flash_sdp(True)
     torch.backends.cuda.enable_mem_efficient_sdp(True)
 
+try:
+    if int(torch_version[0]) == 2 and int(torch_version[2]) >= 5:
+        torch.backends.cuda.allow_fp16_bf16_reduction_math_sdp(True)
+except:
+    logging.warning("Warning, could not set allow_fp16_bf16_reduction_math_sdp")
+
 if args.lowvram:
     set_vram_to = VRAMState.LOW_VRAM
     lowvram_available = True
@@ -322,6 +349,8 @@ def get_torch_device_name(device):
         return "{}".format(device.type)
     elif is_intel_xpu():
         return "{} {}".format(device, torch.xpu.get_device_name(device))
+    elif is_ascend_npu():
+        return "{} {}".format(device, torch.npu.get_device_name(device))
     else:
         return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
@@ -605,7 +634,7 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
             lowvram_model_memory = 0
 
         if vram_set_state == VRAMState.NO_VRAM:
-            lowvram_model_memory = 64 * 1024 * 1024
+            lowvram_model_memory = 0.1
 
         loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
         current_loaded_models.insert(0, loaded_model)
@@ -852,7 +881,6 @@ def vae_offload_device():
 
 
 def vae_dtype(device=None, allowed_dtypes=[]):
-    global VAE_DTYPES
     if args.fp16_vae:
         return torch.float16
     elif args.bf16_vae:
@@ -861,12 +889,14 @@ def vae_dtype(device=None, allowed_dtypes=[]):
         return torch.float32
 
     for d in allowed_dtypes:
-        if d == torch.float16 and should_use_fp16(device, prioritize_performance=False):
-            return d
-        if d in VAE_DTYPES:
+        if d == torch.float16 and should_use_fp16(device):
             return d
+        # NOTE: bfloat16 seems to work on AMD for the VAE but is extremely slow in some cases compared to fp32
+        if d == torch.bfloat16 and (not is_amd()) and should_use_bf16(device):
+            return d
 
-    return VAE_DTYPES[0]
+    return torch.float32
 
 
 def get_autocast_device(dev):
@@ -984,6 +1014,8 @@ def xformers_enabled():
         return False
     if is_intel_xpu():
         return False
+    if is_ascend_npu():
+        return False
     if directml_device:
         return False
     return XFORMERS_IS_AVAILABLE
@@ -1022,17 +1054,25 @@ def pytorch_attention_flash_attention():
         return True
     if is_intel_xpu():
         return True
+    if is_ascend_npu():
+        return True
     return False
 
 
+def mac_version() -> Optional[tuple[int, ...]]:
+    try:
+        return tuple(int(n) for n in platform.mac_ver()[0].split("."))
+    except:
+        return None
+
+
 def force_upcast_attention_dtype():
     upcast = args.force_upcast_attention
-    try:
-        macos_version = tuple(int(n) for n in platform.mac_ver()[0].split("."))
-        if (14, 5) <= macos_version <= (15, 2):  # black image bug on recent versions of macOS
-            upcast = True
-    except:
-        pass
+
+    macos_version = mac_version()
+    if macos_version is not None and ((14, 5) <= macos_version <= (15, 2)):  # black image bug on recent versions of macOS
+        upcast = True
 
     if upcast:
         return torch.float32
     else:
@@ -1052,8 +1092,19 @@ def get_free_memory(dev=None, torch_free_too=False):
         mem_free_total = 1024 * 1024 * 1024  # TODO
         mem_free_torch = mem_free_total
     elif is_intel_xpu():
-        mem_free_total = torch.xpu.get_device_properties(dev).total_memory
-        mem_free_torch = mem_free_total
+        stats = torch.xpu.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
+        mem_free_total = mem_free_xpu + mem_free_torch
+    elif is_ascend_npu():
+        stats = torch.npu.memory_stats(dev)
+        mem_active = stats['active_bytes.all.current']
+        mem_reserved = stats['reserved_bytes.all.current']
+        mem_free_npu, _ = torch.npu.mem_get_info(dev)
+        mem_free_torch = mem_reserved - mem_active
+        mem_free_total = mem_free_npu + mem_free_torch
     else:
         stats = torch.cuda.memory_stats(dev)
         mem_active = stats['active_bytes.all.current']
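The accounting in both new branches is: free memory the allocator can still request from the device, plus memory the allocator already reserved but is not actively using. A sketch of the same computation on CUDA (requires a CUDA device; the XPU/NPU branches above are structurally identical):

```python
import torch

dev = torch.device("cuda:0")
stats = torch.cuda.memory_stats(dev)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_device, _ = torch.cuda.mem_get_info(dev)

mem_free_torch = mem_reserved - mem_active  # reclaimable inside the allocator
print("free total:", mem_free_device + mem_free_torch)
```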
@@ -1107,17 +1158,13 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if FORCE_FP16:
         return True
 
-    if device is not None:
-        if is_device_mps(device):
-            return True
-
     if FORCE_FP32:
         return False
 
     if directml_device:
         return False
 
-    if mps_mode():
+    if (device is not None and is_device_mps(device)) or mps_mode():
         return True
 
     if cpu_mode():
@@ -1126,6 +1173,9 @@ def should_use_fp16(device=None, model_params=0, prioritize_performance=True, ma
     if is_intel_xpu():
         return True
 
+    if is_ascend_npu():
+        return True
+
     if is_amd():
         return True
     try:
@@ -1176,17 +1226,15 @@ def should_use_bf16(device=None, model_params=0, prioritize_performance=True, ma
     if is_device_cpu(device):  # TODO ? bf16 works on CPU but is extremely slow
         return False
 
-    if device is not None:
-        if is_device_mps(device):
-            return True
-
     if FORCE_FP32:
         return False
 
     if directml_device:
         return False
 
-    if mps_mode():
+    if (device is not None and is_device_mps(device)) or mps_mode():
+        if mac_version() < (14,):
+            return False
         return True
 
     if cpu_mode():
@@ -1244,15 +1292,21 @@ def supports_fp8_compute(device=None):
 
 def soft_empty_cache(force=False):
     with model_management_lock:
-        global cpu_state
-        if cpu_state == CPUState.MPS:
-            torch.mps.empty_cache()
-        elif is_intel_xpu():
-            torch.xpu.empty_cache()  # pylint: disable=no-member
-        elif torch.cuda.is_available():
-            if force or is_nvidia():  # This seems to make things worse on ROCm so I only do it for cuda
-                torch.cuda.empty_cache()
-                torch.cuda.ipc_collect()
+        _soft_empty_cache(force=force)
+
+
+def _soft_empty_cache(force=False):
+    global cpu_state
+    if cpu_state == CPUState.MPS:
+        torch.mps.empty_cache()  # pylint: disable=no-member
+    elif is_intel_xpu():
+        torch.xpu.empty_cache()  # pylint: disable=no-member
+    elif is_ascend_npu():
+        torch.npu.empty_cache()  # pylint: disable=no-member
+    elif torch.cuda.is_available():
+        if force or is_nvidia():  # This seems to make things worse on ROCm so I only do it for cuda
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
 
 
 def unload_all_models():
@@ -292,17 +292,29 @@ class VAEDecodeTiled:
         return {"required": {"samples": ("LATENT", ), "vae": ("VAE", ),
                              "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 32}),
                              "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
+                             "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to decode at a time."}),
+                             "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
                              }}
     RETURN_TYPES = ("IMAGE",)
     FUNCTION = "decode"
 
     CATEGORY = "_for_testing"
 
-    def decode(self, vae, samples, tile_size, overlap=64):
+    def decode(self, vae, samples, tile_size, overlap=64, temporal_size=64, temporal_overlap=8):
         if tile_size < overlap * 4:
             overlap = tile_size // 4
+        if temporal_size < temporal_overlap * 2:
+            temporal_overlap = temporal_overlap // 2
+        temporal_compression = vae.temporal_compression_decode()
+        if temporal_compression is not None:
+            temporal_size = max(2, temporal_size // temporal_compression)
+            temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)
+        else:
+            temporal_size = None
+            temporal_overlap = None
 
         compression = vae.spacial_compression_decode()
-        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression)
+        images = vae.decode_tiled(samples["samples"], tile_x=tile_size // compression, tile_y=tile_size // compression, overlap=overlap // compression, tile_t=temporal_size, overlap_t=temporal_overlap)
         if len(images.shape) == 5: #Combine batches
             images = images.reshape(-1, images.shape[-3], images.shape[-2], images.shape[-1])
         return (images, )
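A worked pass through the temporal arithmetic above, assuming the node defaults and a video VAE whose `temporal_compression_decode()` returns 8 (LTXV-style; values are illustrative):

```python
temporal_size, temporal_overlap = 64, 8  # node defaults, in pixel-space frames
temporal_compression = 8                 # assumed vae.temporal_compression_decode()

# 64 < 8 * 2 is False, so the halving branch is skipped.
temporal_size = max(2, temporal_size // temporal_compression)  # -> 8 latent frames per tile
temporal_overlap = min(1, temporal_size // 2, temporal_overlap // temporal_compression)  # -> 1

print(temporal_size, temporal_overlap)  # 8 1
```

Note that the `min(1, ...)` term caps the latent-space overlap at a single frame.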
@@ -326,15 +338,17 @@ class VAEEncodeTiled:
         return {"required": {"pixels": ("IMAGE", ), "vae": ("VAE", ),
                              "tile_size": ("INT", {"default": 512, "min": 64, "max": 4096, "step": 64}),
                              "overlap": ("INT", {"default": 64, "min": 0, "max": 4096, "step": 32}),
+                             "temporal_size": ("INT", {"default": 64, "min": 8, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to encode at a time."}),
+                             "temporal_overlap": ("INT", {"default": 8, "min": 4, "max": 4096, "step": 4, "tooltip": "Only used for video VAEs: Amount of frames to overlap."}),
                              }}
     RETURN_TYPES = ("LATENT",)
     FUNCTION = "encode"
 
     CATEGORY = "_for_testing"
 
-    def encode(self, vae, pixels, tile_size, overlap):
-        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap)
-        return ({"samples":t}, )
+    def encode(self, vae, pixels, tile_size, overlap, temporal_size=64, temporal_overlap=8):
+        t = vae.encode_tiled(pixels[:,:,:,:3], tile_x=tile_size, tile_y=tile_size, overlap=overlap, tile_t=temporal_size, overlap_t=temporal_overlap)
+        return ({"samples": t}, )
 
 class VAEEncodeForInpaint:
     @classmethod
@@ -1664,7 +1678,6 @@ class LoadImage:
 
     def load_image(self, image: str) -> tuple[RGBImageBatch, MaskBatch]:
         image_path = folder_paths.get_annotated_filepath(image)
 
         output_images = []
         output_masks = []
         w, h = None, None
comfy/ops.py (17 changes)
@@ -308,8 +308,10 @@ def fp8_linear(self, input):
         tensor_2d = True
         input = input.unsqueeze(1)
 
+    input_shape = input.shape
+    input_dtype = input.dtype
     if len(input.shape) == 3:
-        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input.dtype)
+        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
         w = w.t()
 
         scale_weight = self.scale_weight
@@ -321,23 +323,24 @@ def fp8_linear(self, input):
 
         if scale_input is None:
             scale_input = torch.ones((), device=input.device, dtype=torch.float32)
-            inn = torch.clamp(input, min=-448, max=448).reshape(-1, input.shape[2]).to(dtype)
+            input = torch.clamp(input, min=-448, max=448, out=input)
+            input = input.reshape(-1, input_shape[2]).to(dtype)
         else:
             scale_input = scale_input.to(input.device)
-            inn = (input * (1.0 / scale_input).to(input.dtype)).reshape(-1, input.shape[2]).to(dtype)
+            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype)
 
         if bias is not None:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
         else:
-            o = torch._scaled_mm(inn, w, out_dtype=input.dtype, scale_a=scale_input, scale_b=scale_weight)
+            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)
 
         if isinstance(o, tuple):
             o = o[0]
 
         if tensor_2d:
-            return o.reshape(input.shape[0], -1)
+            return o.reshape(input_shape[0], -1)
 
-        return o.reshape((-1, input.shape[1], self.weight.shape[0]))
+        return o.reshape((-1, input_shape[1], self.weight.shape[0]))
 
     return None
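For orientation, a minimal sketch of the `torch._scaled_mm` call shape used above. `_scaled_mm` is a private PyTorch API whose signature and return type have shifted between releases (hence the tuple check above), and it needs an FP8-capable GPU, so treat this as an assumption-laden illustration:

```python
import torch

device = "cuda"
a = torch.randn(16, 32, device=device).to(torch.float8_e4m3fn)      # row-major
b = torch.randn(64, 32, device=device).to(torch.float8_e4m3fn).t()  # column-major view
scale_a = torch.ones((), device=device, dtype=torch.float32)
scale_b = torch.ones((), device=device, dtype=torch.float32)

o = torch._scaled_mm(a, b, scale_a=scale_a, scale_b=scale_b, out_dtype=torch.bfloat16)
if isinstance(o, tuple):  # older PyTorch versions returned (out, amax)
    o = o[0]
print(o.shape)  # torch.Size([16, 64])
```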
@@ -15,7 +15,7 @@ def prepare_noise(latent_image, seed, noise_inds=None):
     generator = torch.manual_seed(seed)
     if noise_inds is None:
         return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
 
 
     unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
     noises = []
     for i in range(unique_inds[-1]+1):
comfy/sd.py (48 changes)
@@ -121,7 +121,7 @@ class CLIP:
             model_management.load_models_gpu([self.patcher], force_full_load=True)
         self.layer_idx = None
         self.use_clip_schedule = False
-        logger.debug("CLIP model load device: {}, offload device: {}, current: {}".format(load_device, offload_device, params['device']))
+        logger.info("CLIP model load device: {}, offload device: {}, current: {}, dtype: {}".format(load_device, offload_device, params['device'], dtype))
 
     def clone(self):
         n = CLIP(no_init=True)
@@ -271,6 +271,9 @@ class VAE:
         self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
         self.working_dtypes = [torch.bfloat16, torch.float32]
 
+        self.downscale_index_formula = None
+        self.upscale_index_formula = None
+
         if config is None:
             if "decoder.mid.block_1.mix_factor" in sd:
                 encoder_config = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
@@ -350,7 +353,9 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (1000 * shape[2] * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
             self.memory_used_encode = lambda shape, dtype: (1.5 * max(shape[2], 7) * shape[3] * shape[4] * (6 * 8 * 8)) * model_management.dtype_size(dtype)
             self.upscale_ratio = (lambda a: max(0, a * 6 - 5), 8, 8)
+            self.upscale_index_formula = (6, 8, 8)
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 5) / 6)), 8, 8)
+            self.downscale_index_formula = (6, 8, 8)
             self.working_dtypes = [torch.float16, torch.float32]
         elif "decoder.up_blocks.0.res_blocks.0.conv1.conv.weight" in sd:  # lightricks ltxv
             tensor_conv1 = sd["decoder.up_blocks.0.res_blocks.0.conv1.conv.weight"]
@@ -365,14 +370,18 @@ class VAE:
             self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
             self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
             self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
+            self.upscale_index_formula = (8, 32, 32)
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
+            self.downscale_index_formula = (8, 32, 32)
             self.working_dtypes = [torch.bfloat16, torch.float32]
         elif "decoder.conv_in.conv.weight" in sd:
             ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
             ddconfig["conv3d"] = True
             ddconfig["time_compress"] = 4
             self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+            self.upscale_index_formula = (4, 8, 8)
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+            self.downscale_index_formula = (4, 8, 8)
             self.latent_dim = 3
             self.latent_channels = ddconfig['z_channels'] = sd["decoder.conv_in.conv.weight"].shape[1]
             self.first_stage_model = AutoencoderKL(ddconfig=ddconfig, embed_dim=sd['post_quant_conv.weight'].shape[1])
@@ -405,7 +414,7 @@ class VAE:
         self.output_device = model_management.intermediate_device()
 
         self.patcher = model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
-        logger.debug("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+        logger.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
 
     def vae_encode_crop_pixels(self, pixels):
         downscale_ratio = self.spacial_compression_encode()
@@ -438,7 +447,7 @@ class VAE:
 
     def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
         decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
-        return self.process_output(utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
+        return self.process_output(utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
 
     def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap=64):
         steps = pixel_samples.shape[0] * utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x, tile_y, overlap)
@@ -459,7 +468,7 @@ class VAE:
 
     def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
         encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
-        return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, output_device=self.output_device)
+        return utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
 
     def decode(self, samples_in):
         pixel_samples = None
@@ -491,7 +500,7 @@ class VAE:
         pixel_samples = pixel_samples.to(self.output_device).movedim(1, -1)
         return pixel_samples
 
-    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None):
+    def decode_tiled(self, samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         memory_used = self.memory_used_decode(samples.shape, self.vae_dtype)  # TODO: calculate mem required for tile
         load_models_gpu([self.patcher], memory_required=memory_used)
         dims = samples.ndim - 2
@@ -509,6 +518,13 @@ class VAE:
         elif dims == 2:
             output = self.decode_tiled_(samples, **args)
         elif dims == 3:
+            if overlap_t is None:
+                args["overlap"] = (1, overlap, overlap)
+            else:
+                args["overlap"] = (max(1, overlap_t), overlap, overlap)
+            if tile_t is not None:
+                args["tile_t"] = max(2, tile_t)
+
             output = self.decode_tiled_3d(samples, **args)
         else:
             raise ValueError(f"invalid dims={dims}")
@@ -546,7 +562,7 @@ class VAE:
 
         return samples
 
-    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None):
+    def encode_tiled(self, pixel_samples, tile_x=None, tile_y=None, overlap=None, tile_t=None, overlap_t=None):
         pixel_samples = self.vae_encode_crop_pixels(pixel_samples)
         dims = self.latent_dim
         pixel_samples = pixel_samples.movedim(-1, 1)
@@ -571,7 +587,20 @@ class VAE:
         elif dims == 2:
             samples = self.encode_tiled_(pixel_samples, **args)
         elif dims == 3:
-            samples = self.encode_tiled_3d(pixel_samples, **args)
+            if tile_t is not None:
+                tile_t_latent = max(2, self.downscale_ratio[0](tile_t))
+            else:
+                tile_t_latent = 9999
+            args["tile_t"] = self.upscale_ratio[0](tile_t_latent)
+
+            if overlap_t is None:
+                args["overlap"] = (1, overlap, overlap)
+            else:
+                args["overlap"] = (self.upscale_ratio[0](max(1, min(tile_t_latent // 2, self.downscale_ratio[0](overlap_t)))), overlap, overlap)
+            maximum = pixel_samples.shape[2]
+            maximum = self.upscale_ratio[0](self.downscale_ratio[0](maximum))
+
+            samples = self.encode_tiled_3d(pixel_samples[:,:,:maximum], **args)
         else:
             raise ValueError(f"unsupported values dim {dims}")
@@ -592,6 +621,11 @@ class VAE:
         except:
             return self.downscale_ratio
 
+    def temporal_compression_decode(self):
+        try:
+            return round(self.upscale_ratio[0](8192) / 8192)
+        except:
+            return None
 
 class StyleModel:
     def __init__(self, model, device="cpu"):
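The new helper recovers the integer temporal compression factor from the affine upscale formula by evaluating it at a large value, where the constant offset washes out. A quick check with the LTXV-style formula from earlier in this diff:

```python
upscale = lambda a: max(0, a * 8 - 7)  # latent frames -> pixel frames
print(round(upscale(8192) / 8192))     # (65529 / 8192) ≈ 7.999 -> 8
```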
@@ -627,6 +627,8 @@ class PixArtAlpha(supported_models_base.BASE):
     unet_extra_config = {}
     latent_format = latent_formats.SD15
 
+    memory_usage_factor = 0.5
+
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]
@@ -663,6 +665,8 @@ class HunyuanDiT(supported_models_base.BASE):
 
     latent_format = latent_formats.SDXL
 
+    memory_usage_factor = 1.3
+
     vae_key_prefix = ["vae."]
     text_encoder_key_prefix = ["text_encoders."]
@@ -912,7 +912,7 @@ def get_tiled_scale_steps(width, height, tile_x, tile_y, overlap):
 
 
 @torch.inference_mode()
-def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, pbar=None):
+def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_amount=4, out_channels=3, output_device="cpu", downscale=False, index_formulas=None, pbar=None):
     dims = len(tile)
 
     if not (isinstance(upscale_amount, (tuple, list))):
@@ -921,6 +921,12 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
     if not (isinstance(overlap, (tuple, list))):
         overlap = [overlap] * dims
 
+    if index_formulas is None:
+        index_formulas = upscale_amount
+
+    if not (isinstance(index_formulas, (tuple, list))):
+        index_formulas = [index_formulas] * dims
+
     def get_upscale(dim, val):
         up = upscale_amount[dim]
         if callable(up):
@@ -935,10 +941,26 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
         else:
             return val / up
 
+    def get_upscale_pos(dim, val):
+        up = index_formulas[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return up * val
+
+    def get_downscale_pos(dim, val):
+        up = index_formulas[dim]
+        if callable(up):
+            return up(val)
+        else:
+            return val / up
+
     if downscale:
         get_scale = get_downscale
+        get_pos = get_downscale_pos
     else:
         get_scale = get_upscale
+        get_pos = get_upscale_pos
 
     def mult_list_upscale(a):
         out = []
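The reason positions get their own formula: for a causal video VAE, the output size of a tile and the output position of a tile do not follow the same mapping. A hedged illustration using the LTXV-style temporal numbers from this commit:

```python
size_formula = lambda a: max(0, a * 8 - 7)  # latent frame count -> pixel frame count
index_formula = 8                           # latent frame index -> pixel frame index

print(size_formula(4))    # 25: a 4-frame latent tile decodes to 25 pixel frames
print(index_formula * 2)  # 16: a tile starting at latent frame 2 lands at pixel frame 16
print(size_formula(2))    # 9: where the old size-based position math would have placed it
```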
@@ -970,7 +992,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
                 pos = max(0, min(s.shape[d + 2] - overlap[d], it[d]))
                 l = min(tile[d], s.shape[d + 2] - pos)
                 s_in = s_in.narrow(d + 2, pos, l)
-                upscaled.append(round(get_scale(d, pos)))
+                upscaled.append(round(get_pos(d, pos)))
 
             ps = function(s_in).to(output_device)
             mask = torch.ones_like(ps)
@@ -306,7 +306,7 @@ class FeatherMask:
             output[:, -y, :] *= feather_rate
 
         return (output,)
 
 
 class GrowMask:
     @classmethod
     def INPUT_TYPES(cls):
|
||||
"tapered_corners": ("BOOLEAN", {"default": True}),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
CATEGORY = "mask"
|
||||
|
||||
RETURN_TYPES = ("MASK",)
|
||||
|
||||
@@ -40,7 +40,7 @@ class LatentRebatch:
             return slices, indexable[num * batch_size:]
         else:
             return slices, None
 
     @staticmethod
     def slice_batch(batch, num, batch_size):
         result = [LatentRebatch.get_slices(x, num, batch_size) for x in batch]
@@ -81,7 +81,7 @@ class LatentRebatch:
         if current_batch[0].shape[0] > batch_size:
             num = current_batch[0].shape[0] // batch_size
             sliced, remainder = self.slice_batch(current_batch, num, batch_size)
 
             for i in range(num):
                 output_list.append({'samples': sliced[0][i], 'noise_mask': sliced[1][i], 'batch_index': sliced[2][i]})
@@ -40,9 +40,8 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
         return do_nothing, do_nothing
 
     gather = mps_gather_workaround if metric.device.type == "mps" else torch.gather
 
     with torch.no_grad():
-
         hsy, wsx = h // sy, w // sx
 
         # For each sy by sx kernel, randomly assign one token to be dst and the rest src
@@ -50,7 +49,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
             rand_idx = torch.zeros(hsy, wsx, 1, device=metric.device, dtype=torch.int64)
         else:
             rand_idx = torch.randint(sy*sx, size=(hsy, wsx, 1), device=metric.device)
 
         # The image might not divide sx and sy, so we need to work on a view of the top left if the idx buffer instead
         idx_buffer_view = torch.zeros(hsy, wsx, sy*sx, device=metric.device, dtype=torch.int64)
         idx_buffer_view.scatter_(dim=2, index=rand_idx, src=-torch.ones_like(rand_idx, dtype=rand_idx.dtype))
@@ -99,7 +98,7 @@ def bipartite_soft_matching_random2d(metric: torch.Tensor,
     def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
         src, dst = split(x)
         n, t1, c = src.shape
 
         unm = gather(src, dim=-2, index=unm_idx.expand(n, t1 - r, c))
         src = gather(src, dim=-2, index=src_idx.expand(n, r, c))
         dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)