mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-21 23:39:35 +08:00
Merge remote-tracking branch 'upstream/master' into qwen35
This commit is contained in:
commit
53c83214ef
11
README.md
11
README.md
@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
|||||||
|
|
||||||
## Get Started
|
## Get Started
|
||||||
|
|
||||||
|
### Local
|
||||||
|
|
||||||
#### [Desktop Application](https://www.comfy.org/download)
|
#### [Desktop Application](https://www.comfy.org/download)
|
||||||
- The easiest way to get started.
|
- The easiest way to get started.
|
||||||
- Available on Windows & macOS.
|
- Available on Windows & macOS.
|
||||||
@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
|||||||
#### [Manual Install](#manual-install-windows-linux)
|
#### [Manual Install](#manual-install-windows-linux)
|
||||||
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
|
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
|
||||||
|
|
||||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
### Cloud
|
||||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
|
||||||
|
#### [Comfy Cloud](https://www.comfy.org/cloud)
|
||||||
|
- Our official paid cloud version for those who can't afford local hardware.
|
||||||
|
|
||||||
|
## Examples
|
||||||
|
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||||
|
|||||||
@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
|
|||||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||||
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
||||||
|
|
||||||
|
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
|
||||||
|
|
||||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||||
|
|
||||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||||
|
|||||||
@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
|
|||||||
output_block[i:i + slice_size].copy_(block)
|
output_block[i:i + slice_size].copy_(block)
|
||||||
|
|
||||||
return output_fp4, to_blocked(output_block, flatten=False)
|
return output_fp4, to_blocked(output_block, flatten=False)
|
||||||
|
|
||||||
|
|
||||||
|
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
    """Quantize a 2-D tensor to MXFP8 using stochastic rounding.

    Every run of 32 consecutive values along the last dimension shares one
    power-of-two scale stored as a biased 8-bit exponent (E8M0, bias 127);
    the scaled values are stochastically rounded to ``torch.float8_e4m3fn``.

    Args:
        x: 2-D tensor. After optional padding its column count must be a
            multiple of 32 so it can be reshaped into blocks.
        pad_32x: when true, zero-pad rows and columns up to the next
            multiple of 32 before quantizing.
        seed: forwarded to ``stochastic_rounding``.

    Returns:
        A tuple ``(output_fp8, block_scales)``: the quantized data with the
        (possibly padded) 2-D shape, and the per-block scales passed through
        ``to_blocked(..., flatten=False)`` and viewed as
        ``torch.float8_e8m0fnu``.
    """
    def roundup(x_val, multiple):
        return ((x_val + multiple - 1) // multiple) * multiple

    if pad_32x:
        rows, cols = x.shape
        padded_rows = roundup(rows, 32)
        padded_cols = roundup(cols, 32)
        if padded_rows != rows or padded_cols != cols:
            x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))

    F8_E4M3_MAX = 448.0
    E8M0_BIAS = 127
    BLOCK_SIZE = 32
    # float32 bit pattern of 2**-127. That value is subnormal in float32, so it
    # cannot be produced by placing a biased exponent of 0 in the exponent field.
    F32_BITS_2_POW_NEG_127 = 0x00400000

    rows, cols = x.shape
    x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
    max_abs = torch.amax(torch.abs(x_blocked), dim=-1)

    # E8M0 block scales (power-of-2 exponents)
    scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
    exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
    block_scales_e8m0 = exp_biased.to(torch.uint8)

    zero_mask = (max_abs == 0)
    # Decode the exponents to float32 by writing them into the exponent bits.
    # A biased exponent of 0 would decode to 0.0 that way (and the division
    # below would yield inf/NaN for tiny non-zero blocks), so map it explicitly
    # to the subnormal that represents 2**-127.
    scale_bits = exp_biased << 23
    scale_bits = torch.where(exp_biased == 0, torch.full_like(scale_bits, F32_BITS_2_POW_NEG_127), scale_bits)
    block_scales_f32 = scale_bits.view(torch.float32)
    # All-zero blocks are divided by 1 so they stay exactly zero.
    block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)

    # Scale per-block then stochastic round
    data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
    output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)

    # Store exponent 0 for all-zero blocks.
    block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
    return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
|
|||||||
from .pixel_norm import PixelNorm
|
from .pixel_norm import PixelNorm
|
||||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
import comfy.model_management
|
||||||
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||||
|
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
@ -536,7 +537,7 @@ class Decoder(nn.Module):
|
|||||||
mark_conv3d_ended(self.conv_out)
|
mark_conv3d_ended(self.conv_out)
|
||||||
sample = self.conv_out(sample, causal=self.causal)
|
sample = self.conv_out(sample, causal=self.causal)
|
||||||
if sample is not None and sample.shape[2] > 0:
|
if sample is not None and sample.shape[2] > 0:
|
||||||
output.append(sample)
|
output.append(sample.to(comfy.model_management.intermediate_device()))
|
||||||
return
|
return
|
||||||
|
|
||||||
up_block = self.up_blocks[idx]
|
up_block = self.up_blocks[idx]
|
||||||
|
|||||||
@ -1,9 +1,68 @@
|
|||||||
import math
|
import math
|
||||||
|
import ctypes
|
||||||
|
import threading
|
||||||
|
import dataclasses
|
||||||
import torch
|
import torch
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
from comfy.quant_ops import QuantizedTensor
|
from comfy.quant_ops import QuantizedTensor
|
||||||
|
|
||||||
|
|
||||||
|
class TensorFileSlice(NamedTuple):
    """Location of a tensor's bytes inside an open file.

    Consumed by ``read_tensor_file_slice_into`` to read the bytes straight
    from the file into a destination tensor.
    """

    # File object that is seeked and read with ``readinto``; may be None.
    file_ref: object
    # ``threading.get_ident()`` value that must match the reading thread.
    thread_id: int
    # Position passed to ``seek`` before reading.
    offset: int
    # Number of bytes to read.
    size: int
|
||||||
|
|
||||||
|
|
||||||
|
def read_tensor_file_slice_into(tensor, destination):
    """Try to fill ``destination`` by reading ``tensor``'s bytes directly from its file.

    The source location comes from a ``_comfy_tensor_file_slice`` attribute
    (a ``TensorFileSlice``) on the tensor's untyped storage. For a
    ``QuantizedTensor`` the quantized payload is read this way and the layout
    parameters are copied over, keeping the destination's ``orig_dtype``.

    Args:
        tensor: source tensor (plain or ``QuantizedTensor``).
        destination: tensor to receive the bytes.

    Returns:
        True when ``destination`` was completely filled from the file. False
        when the fast path does not apply or a read failed; in that case the
        caller is expected to copy the data some other way (``destination``
        may have been partially written).
    """
    if isinstance(tensor, QuantizedTensor):
        if not isinstance(destination, QuantizedTensor):
            return False
        if tensor._layout_cls != destination._layout_cls:
            return False

        if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
            return False

        # Take over the source parameters but keep the destination's own orig_dtype.
        dst_orig_dtype = destination._params.orig_dtype
        destination._params.copy_from(tensor._params, non_blocking=False)
        destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
        return True

    info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
    if info is None:
        return False

    file_obj = info.file_ref
    if (destination.device.type != "cpu"
            or file_obj is None
            # The file object is only used from the thread recorded in the slice.
            or threading.get_ident() != info.thread_id
            # Bytes are written linearly from data_ptr(), which is only
            # correct when the destination is densely laid out.
            or not destination.is_contiguous()
            or destination.numel() * destination.element_size() < info.size):
        return False

    if info.size == 0:
        return True

    # Expose the destination memory as a writable buffer for readinto().
    buf_type = ctypes.c_ubyte * info.size
    view = memoryview(buf_type.from_address(destination.data_ptr()))

    try:
        try:
            file_obj.seek(info.offset)
        except OSError:
            return False
        done = 0
        while done < info.size:
            try:
                n = file_obj.readinto(view[done:])
            except OSError:
                return False
            # readinto() may return None (no data available on a
            # non-blocking stream) or 0 at end of file: both mean failure.
            if n is None or n <= 0:
                return False
            done += n
        return True
    finally:
        view.release()
|
||||||
|
|
||||||
class TensorGeometry(NamedTuple):
|
class TensorGeometry(NamedTuple):
|
||||||
shape: any
|
shape: any
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
|
|||||||
@ -505,6 +505,28 @@ def module_size(module):
|
|||||||
module_mem += t.nbytes
|
module_mem += t.nbytes
|
||||||
return module_mem
|
return module_mem
|
||||||
|
|
||||||
|
def module_mmap_residency(module, free=False):
    """Measure how much of a module's state is marked as mmap-touched.

    A tensor counts as touched when its untyped storage (the ``_qdata``
    storage for a ``QuantizedTensor``) has a truthy
    ``_comfy_tensor_mmap_touched`` attribute.

    Args:
        module: module whose ``state_dict()`` is inspected.
        free: when true, clear the touched flag on every touched storage and
            call ``bounce()`` once on each distinct mmap object found in
            ``_comfy_tensor_mmap_refs[0]``.

    Returns:
        ``(touched_bytes, total_bytes)`` summed over ``nbytes`` of the state
        dict entries.
    """
    total_bytes = 0
    touched_bytes = 0
    already_bounced = set()

    for tensor in module.state_dict().values():
        total_bytes += tensor.nbytes

        if isinstance(tensor, comfy.quant_ops.QuantizedTensor):
            backing = tensor._qdata.untyped_storage()
        else:
            backing = tensor.untyped_storage()

        if getattr(backing, "_comfy_tensor_mmap_touched", False):
            touched_bytes += tensor.nbytes
            if free:
                backing._comfy_tensor_mmap_touched = False
                mapping = backing._comfy_tensor_mmap_refs[0]
                # Several tensors can share one mmap; bounce each only once.
                if mapping not in already_bounced:
                    mapping.bounce()
                    already_bounced.add(mapping)

    return touched_bytes, total_bytes
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self._set_model(model)
|
self._set_model(model)
|
||||||
@ -532,6 +554,9 @@ class LoadedModel:
|
|||||||
def model_memory(self):
|
def model_memory(self):
|
||||||
return self.model.model_size()
|
return self.model.model_size()
|
||||||
|
|
||||||
|
def model_mmap_residency(self, free=False):
    """Forward to the wrapped model's ``model_mmap_residency`` and return its result."""
    residency = self.model.model_mmap_residency(free=free)
    return residency
|
||||||
|
|
||||||
def model_loaded_memory(self):
|
def model_loaded_memory(self):
|
||||||
return self.model.loaded_size()
|
return self.model.loaded_size()
|
||||||
|
|
||||||
@ -633,7 +658,7 @@ def extra_reserved_memory():
|
|||||||
def minimum_inference_memory():
|
def minimum_inference_memory():
|
||||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||||
|
|
||||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
|
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||||
cleanup_models_gc()
|
cleanup_models_gc()
|
||||||
unloaded_model = []
|
unloaded_model = []
|
||||||
can_unload = []
|
can_unload = []
|
||||||
@ -646,13 +671,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
for x in sorted(can_unload):
|
can_unload_sorted = sorted(can_unload)
|
||||||
|
for x in can_unload_sorted:
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
memory_to_free = 1e32
|
memory_to_free = 1e32
|
||||||
ram_to_free = 1e32
|
pins_to_free = 1e32
|
||||||
if not DISABLE_SMART_MEMORY:
|
if not DISABLE_SMART_MEMORY:
|
||||||
memory_to_free = memory_required - get_free_memory(device)
|
memory_to_free = memory_required - get_free_memory(device)
|
||||||
ram_to_free = ram_required - get_free_ram()
|
pins_to_free = pins_required - get_free_ram()
|
||||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||||
#don't actually unload dynamic models for the sake of other dynamic models
|
#don't actually unload dynamic models for the sake of other dynamic models
|
||||||
#as that works on-demand.
|
#as that works on-demand.
|
||||||
@ -661,9 +687,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
unloaded_model.append(i)
|
unloaded_model.append(i)
|
||||||
if ram_to_free > 0:
|
if pins_to_free > 0:
|
||||||
|
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
|
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||||
|
|
||||||
|
for x in can_unload_sorted:
|
||||||
|
i = x[-1]
|
||||||
|
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||||
|
if ram_to_free <= 0 and i not in unloaded_model:
|
||||||
|
continue
|
||||||
|
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
||||||
|
if resident_memory > 0:
|
||||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
|
||||||
|
|
||||||
for i in sorted(unloaded_model, reverse=True):
|
for i in sorted(unloaded_model, reverse=True):
|
||||||
unloaded_models.append(current_loaded_models.pop(i))
|
unloaded_models.append(current_loaded_models.pop(i))
|
||||||
@ -729,17 +764,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
|
|
||||||
|
|
||||||
total_memory_required = {}
|
total_memory_required = {}
|
||||||
|
total_pins_required = {}
|
||||||
total_ram_required = {}
|
total_ram_required = {}
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
device = loaded_model.device
|
||||||
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
|
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||||
#want to do.
|
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
||||||
#FIXME: This should subtract off the to_load current pin consumption.
|
pinned_memory = loaded_model.model.pinned_memory_size()
|
||||||
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
|
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
||||||
|
#make this JIT to keep as much pinned as possible.
|
||||||
|
pins_required = model_memory - pinned_memory
|
||||||
|
ram_required = model_memory - resident_memory
|
||||||
|
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
||||||
|
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
|
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||||
|
device,
|
||||||
|
for_dynamic=free_for_dynamic,
|
||||||
|
pins_required=total_pins_required[device],
|
||||||
|
ram_required=total_ram_required[device])
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
@ -1005,6 +1050,12 @@ def intermediate_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
def intermediate_dtype():
    """Return ``torch.float16`` when ``args.fp16_intermediates`` is set, else ``torch.float32``."""
    return torch.float16 if args.fp16_intermediates else torch.float32
|
||||||
|
|
||||||
def vae_device():
|
def vae_device():
|
||||||
if args.cpu_vae:
|
if args.cpu_vae:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
@ -1225,6 +1276,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
|||||||
dest_view = dest_views.pop(0)
|
dest_view = dest_views.pop(0)
|
||||||
if tensor is None:
|
if tensor is None:
|
||||||
continue
|
continue
|
||||||
|
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||||
|
continue
|
||||||
|
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||||
|
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||||
|
storage._comfy_tensor_mmap_touched = True
|
||||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
@ -1662,6 +1718,19 @@ def supports_nvfp4_compute(device=None):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def supports_mxfp8_compute(device=None):
    """Report whether MXFP8 compute can be used on ``device``.

    Requires an NVIDIA backend, torch 2.10 or newer, and a CUDA device whose
    properties report ``major >= 10``.
    """
    # Cheap global checks first; device properties are only queried when they pass.
    if not is_nvidia() or torch_version_numeric < (2, 10):
        return False

    return torch.cuda.get_device_properties(device).major >= 10
|
||||||
|
|
||||||
def extended_fp16_support():
|
def extended_fp16_support():
|
||||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||||
if torch_version_numeric < (2, 7):
|
if torch_version_numeric < (2, 7):
|
||||||
|
|||||||
@ -297,6 +297,9 @@ class ModelPatcher:
|
|||||||
self.size = comfy.model_management.module_size(self.model)
|
self.size = comfy.model_management.module_size(self.model)
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
|
def model_mmap_residency(self, free=False):
    """Return ``comfy.model_management.module_mmap_residency`` for the patched model."""
    residency = comfy.model_management.module_mmap_residency(self.model, free=free)
    return residency
|
||||||
|
|
||||||
def get_ram_usage(self):
|
def get_ram_usage(self):
|
||||||
return self.model_size()
|
return self.model_size()
|
||||||
|
|
||||||
@ -1063,6 +1066,10 @@ class ModelPatcher:
|
|||||||
|
|
||||||
return self.model.model_loaded_weight_memory - current_used
|
return self.model.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
|
def pinned_memory_size(self):
    """Return the pinned memory attributed to this patcher, which is always 0 here.

    Pinned memory pressure tracking is only implemented for DynamicVram
    loading.
    """
    return 0
|
||||||
|
|
||||||
def partially_unload_ram(self, ram_to_unload):
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -1653,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
return freed
|
return freed
|
||||||
|
|
||||||
|
def pinned_memory_size(self):
    """Sum the byte size of every pin buffer attached to modules in the dynamic load list."""
    pinned_bytes = 0
    for entry in self._load_list(for_dynamic=True):
        # Only the fifth field of each load-list entry (the module) is needed.
        _, _, _, _, submodule, _ = entry
        buffer = comfy.pinned_memory.get_pin(submodule)
        if buffer is None:
            continue
        pinned_bytes += buffer.numel() * buffer.element_size()
    return pinned_bytes
|
||||||
|
|
||||||
def partially_unload_ram(self, ram_to_unload):
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||||
for x in loading:
|
for x in loading:
|
||||||
|
|||||||
121
comfy/ops.py
121
comfy/ops.py
@ -306,6 +306,33 @@ class CastWeightBiasOp:
|
|||||||
bias_function = []
|
bias_function = []
|
||||||
|
|
||||||
class disable_weight_init:
|
class disable_weight_init:
|
||||||
|
@staticmethod
|
||||||
|
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
||||||
|
missing_keys, unexpected_keys, weight_shape,
|
||||||
|
bias_shape=None):
|
||||||
|
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||||
|
prefix_len = len(prefix)
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
key = k[prefix_len:]
|
||||||
|
if key == "weight":
|
||||||
|
if not assign_to_params_buffers:
|
||||||
|
v = v.clone()
|
||||||
|
module.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
elif bias_shape is not None and key == "bias" and v is not None:
|
||||||
|
if not assign_to_params_buffers:
|
||||||
|
v = v.clone()
|
||||||
|
module.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
else:
|
||||||
|
unexpected_keys.append(k)
|
||||||
|
|
||||||
|
if module.weight is None:
|
||||||
|
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
|
||||||
|
missing_keys.append(prefix + "weight")
|
||||||
|
|
||||||
|
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
|
||||||
|
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
|
||||||
|
missing_keys.append(prefix + "bias")
|
||||||
|
|
||||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||||
|
|
||||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||||
@ -333,29 +360,16 @@ class disable_weight_init:
|
|||||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
disable_weight_init._lazy_load_from_state_dict(
|
||||||
prefix_len = len(prefix)
|
self,
|
||||||
for k,v in state_dict.items():
|
state_dict,
|
||||||
if k[prefix_len:] == "weight":
|
prefix,
|
||||||
if not assign_to_params_buffers:
|
local_metadata,
|
||||||
v = v.clone()
|
missing_keys,
|
||||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
unexpected_keys,
|
||||||
elif k[prefix_len:] == "bias" and v is not None:
|
weight_shape=(self.in_features, self.out_features),
|
||||||
if not assign_to_params_buffers:
|
bias_shape=(self.out_features,),
|
||||||
v = v.clone()
|
)
|
||||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
|
||||||
else:
|
|
||||||
unexpected_keys.append(k)
|
|
||||||
|
|
||||||
#Reconcile default construction of the weight if its missing.
|
|
||||||
if self.weight is None:
|
|
||||||
v = torch.zeros(self.in_features, self.out_features)
|
|
||||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
|
||||||
missing_keys.append(prefix+"weight")
|
|
||||||
if self.bias is None and self.comfy_need_lazy_init_bias:
|
|
||||||
v = torch.zeros(self.out_features,)
|
|
||||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
|
||||||
missing_keys.append(prefix+"bias")
|
|
||||||
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
@ -547,6 +561,48 @@ class disable_weight_init:
|
|||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||||
|
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
             norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
             _freeze=False, device=None, dtype=None):
    """Build the embedding, deferring weight allocation when lazy loading is active.

    Lazy loading is used only when both ``comfy.model_management.WINDOWS``
    and ``comfy.memory_management.aimdo_enabled`` are truthy; otherwise the
    regular ``torch.nn.Embedding`` constructor runs with all arguments.
    In the lazy path ``_weight``, ``_freeze`` and ``device`` are not used.
    """
    lazy_init = comfy.model_management.WINDOWS and comfy.memory_management.aimdo_enabled
    if not lazy_init:
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
                         norm_type, scale_grad_by_freq, sparse, _weight,
                         _freeze, device, dtype)
        return

    # Skip torch.nn.Embedding.__init__ and only record the configuration.
    torch.nn.Module.__init__(self)
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    self.sparse = sparse

    # Keep shape/dtype visible for module introspection without reserving storage.
    if dtype is not None:
        embedding_dtype = dtype
    else:
        embedding_dtype = torch.get_default_dtype()
    placeholder = torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype)
    self.weight = torch.nn.Parameter(placeholder, requires_grad=False)
    self.bias = None
    self.weight_comfy_model_dtype = dtype
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                          strict, missing_keys, unexpected_keys, error_msgs):
    """Load the embedding weight, using the lazy loader when lazy loading is active.

    When ``comfy.model_management.WINDOWS`` and
    ``comfy.memory_management.aimdo_enabled`` are both truthy the weight is
    assigned through ``disable_weight_init._lazy_load_from_state_dict``
    (no bias shape is passed); otherwise the parent implementation is used.
    """
    lazy_init = comfy.model_management.WINDOWS and comfy.memory_management.aimdo_enabled
    if not lazy_init:
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                             missing_keys, unexpected_keys, error_msgs)

    expected_shape = (self.num_embeddings, self.embedding_dim)
    disable_weight_init._lazy_load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        missing_keys,
        unexpected_keys,
        weight_shape=expected_shape,
    )
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
self.bias = None
|
self.bias = None
|
||||||
return None
|
return None
|
||||||
@ -801,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
orig_shape=(self.out_features, self.in_features),
|
orig_shape=(self.out_features, self.in_features),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.quant_format == "mxfp8":
|
||||||
|
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||||
|
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||||
|
dtype=torch.uint8)
|
||||||
|
|
||||||
|
if block_scale is None:
|
||||||
|
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||||
|
|
||||||
|
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||||
|
|
||||||
|
params = layout_cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=(self.out_features, self.in_features),
|
||||||
|
)
|
||||||
|
|
||||||
elif self.quant_format == "nvfp4":
|
elif self.quant_format == "nvfp4":
|
||||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||||
@ -950,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||||
|
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
|
||||||
|
|
||||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||||
logging.info("Using mixed precision operations")
|
logging.info("Using mixed precision operations")
|
||||||
disabled = set()
|
disabled = set()
|
||||||
if not nvfp4_compute:
|
if not nvfp4_compute:
|
||||||
disabled.add("nvfp4")
|
disabled.add("nvfp4")
|
||||||
|
if not mxfp8_compute:
|
||||||
|
disabled.add("mxfp8")
|
||||||
if not fp8_compute:
|
if not fp8_compute:
|
||||||
disabled.add("float8_e4m3fn")
|
disabled.add("float8_e4m3fn")
|
||||||
disabled.add("float8_e5m2")
|
disabled.add("float8_e5m2")
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
|
import comfy_aimdo.host_buffer
|
||||||
|
import comfy_aimdo.torch
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
@ -12,18 +13,31 @@ def pin_memory(module):
|
|||||||
return
|
return
|
||||||
#FIXME: This is a RAM cache trigger event
|
#FIXME: This is a RAM cache trigger event
|
||||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||||
pin = torch.empty((size,), dtype=torch.uint8)
|
|
||||||
if comfy.model_management.pin_memory(pin):
|
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||||
module._pin = pin
|
|
||||||
else:
|
|
||||||
module.pin_failed = True
|
module.pin_failed = True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
|
||||||
|
except RuntimeError:
|
||||||
|
module.pin_failed = True
|
||||||
|
return False
|
||||||
|
|
||||||
|
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
|
||||||
|
module._pin_hostbuf = hostbuf
|
||||||
|
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def unpin_memory(module):
|
def unpin_memory(module):
|
||||||
if get_pin(module) is None:
|
if get_pin(module) is None:
|
||||||
return 0
|
return 0
|
||||||
size = module._pin.numel() * module._pin.element_size()
|
size = module._pin.numel() * module._pin.element_size()
|
||||||
comfy.model_management.unpin_memory(module._pin)
|
|
||||||
|
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||||
|
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||||
|
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||||
|
|
||||||
del module._pin
|
del module._pin
|
||||||
|
del module._pin_hostbuf
|
||||||
return size
|
return size
|
||||||
|
|||||||
@ -43,6 +43,18 @@ except ImportError as e:
|
|||||||
def get_layout_class(name):
|
def get_layout_class(name):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
_CK_MXFP8_AVAILABLE = False
|
||||||
|
if _CK_AVAILABLE:
|
||||||
|
try:
|
||||||
|
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
|
||||||
|
_CK_MXFP8_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
|
||||||
|
|
||||||
|
if not _CK_MXFP8_AVAILABLE:
|
||||||
|
class _CKMxfp8Layout:
|
||||||
|
pass
|
||||||
|
|
||||||
import comfy.float
|
import comfy.float
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
|||||||
return qdata, params
|
return qdata, params
|
||||||
|
|
||||||
|
|
||||||
|
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
|
if tensor.dim() != 2:
|
||||||
|
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
|
||||||
|
|
||||||
|
orig_dtype = tensor.dtype
|
||||||
|
orig_shape = tuple(tensor.shape)
|
||||||
|
|
||||||
|
padded_shape = cls.get_padded_shape(orig_shape)
|
||||||
|
needs_padding = padded_shape != orig_shape
|
||||||
|
|
||||||
|
if stochastic_rounding > 0:
|
||||||
|
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
|
||||||
|
else:
|
||||||
|
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
|
||||||
|
|
||||||
|
params = cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=orig_dtype,
|
||||||
|
orig_shape=orig_shape,
|
||||||
|
)
|
||||||
|
return qdata, params
|
||||||
|
|
||||||
|
|
||||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||||
@classmethod
|
@classmethod
|
||||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
|||||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||||
|
if _CK_MXFP8_AVAILABLE:
|
||||||
|
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||||
|
|
||||||
QUANT_ALGOS = {
|
QUANT_ALGOS = {
|
||||||
"float8_e4m3fn": {
|
"float8_e4m3fn": {
|
||||||
@ -157,6 +196,14 @@ QUANT_ALGOS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _CK_MXFP8_AVAILABLE:
|
||||||
|
QUANT_ALGOS["mxfp8"] = {
|
||||||
|
"storage_t": torch.float8_e4m3fn,
|
||||||
|
"parameters": {"weight_scale", "input_scale"},
|
||||||
|
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
|
||||||
|
"group_size": 32,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Re-exports for backward compatibility
|
# Re-exports for backward compatibility
|
||||||
|
|||||||
27
comfy/sd.py
27
comfy/sd.py
@ -872,13 +872,16 @@ class VAE:
|
|||||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||||
return pixels
|
return pixels
|
||||||
|
|
||||||
|
def vae_output_dtype(self):
|
||||||
|
return model_management.intermediate_dtype()
|
||||||
|
|
||||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
output = self.process_output(
|
output = self.process_output(
|
||||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||||
@ -888,16 +891,16 @@ class VAE:
|
|||||||
|
|
||||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||||
if samples.ndim == 3:
|
if samples.ndim == 3:
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
else:
|
else:
|
||||||
og_shape = samples.shape
|
og_shape = samples.shape
|
||||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
|
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||||
|
|
||||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||||
|
|
||||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
@ -906,7 +909,7 @@ class VAE:
|
|||||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
@ -915,7 +918,7 @@ class VAE:
|
|||||||
|
|
||||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||||
if self.latent_dim == 1:
|
if self.latent_dim == 1:
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
out_channels = self.latent_channels
|
out_channels = self.latent_channels
|
||||||
upscale_amount = 1 / self.downscale_ratio
|
upscale_amount = 1 / self.downscale_ratio
|
||||||
else:
|
else:
|
||||||
@ -924,7 +927,7 @@ class VAE:
|
|||||||
tile_x = tile_x // extra_channel_size
|
tile_x = tile_x // extra_channel_size
|
||||||
overlap = overlap // extra_channel_size
|
overlap = overlap // extra_channel_size
|
||||||
upscale_amount = 1 / self.downscale_ratio
|
upscale_amount = 1 / self.downscale_ratio
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
||||||
|
|
||||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||||
if self.latent_dim == 1:
|
if self.latent_dim == 1:
|
||||||
@ -933,7 +936,7 @@ class VAE:
|
|||||||
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
||||||
|
|
||||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||||
|
|
||||||
def decode(self, samples_in, vae_options={}):
|
def decode(self, samples_in, vae_options={}):
|
||||||
@ -951,9 +954,9 @@ class VAE:
|
|||||||
|
|
||||||
for x in range(0, samples_in.shape[0], batch_number):
|
for x in range(0, samples_in.shape[0], batch_number):
|
||||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
|
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
|
||||||
if pixel_samples is None:
|
if pixel_samples is None:
|
||||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
pixel_samples[x:x+batch_number] = out
|
pixel_samples[x:x+batch_number] = out
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_management.raise_non_oom(e)
|
model_management.raise_non_oom(e)
|
||||||
@ -1026,9 +1029,9 @@ class VAE:
|
|||||||
samples = None
|
samples = None
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||||
if samples is None:
|
if samples is None:
|
||||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
samples[x:x + batch_number] = out
|
samples[x:x + batch_number] = out
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@ -20,6 +20,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import struct
|
import struct
|
||||||
|
import ctypes
|
||||||
|
import os
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -32,7 +34,7 @@ from einops import rearrange
|
|||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import mmap
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||||
@ -81,14 +83,17 @@ _TYPES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def load_safetensors(ckpt):
|
def load_safetensors(ckpt):
|
||||||
f = open(ckpt, "rb")
|
import comfy_aimdo.model_mmap
|
||||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
|
||||||
mv = memoryview(mapping)
|
|
||||||
|
|
||||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
f = open(ckpt, "rb", buffering=0)
|
||||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||||
|
file_size = os.path.getsize(ckpt)
|
||||||
|
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||||
|
|
||||||
mv = mv[8 + header_size:]
|
header_size = struct.unpack("<Q", mv[:8])[0]
|
||||||
|
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
|
||||||
|
|
||||||
|
mv = mv[(data_base_offset := 8 + header_size):]
|
||||||
|
|
||||||
sd = {}
|
sd = {}
|
||||||
for name, info in header.items():
|
for name, info in header.items():
|
||||||
@ -102,7 +107,14 @@ def load_safetensors(ckpt):
|
|||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
#We are working with read-only RAM by design
|
#We are working with read-only RAM by design
|
||||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||||
|
storage = tensor.untyped_storage()
|
||||||
|
setattr(storage,
|
||||||
|
"_comfy_tensor_file_slice",
|
||||||
|
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||||
|
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||||
|
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
||||||
|
sd[name] = tensor
|
||||||
|
|
||||||
return sd, header.get("__metadata__", {}),
|
return sd, header.get("__metadata__", {}),
|
||||||
|
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
comfyui_manager==4.1b2
|
comfyui_manager==4.1b4
|
||||||
@ -32,7 +32,7 @@ async def cache_control(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
||||||
response.headers.setdefault("Cache-Control", "no-cache")
|
response.headers.setdefault("Cache-Control", "no-store")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# Early return for non-image files - no cache headers needed
|
# Early return for non-image files - no cache headers needed
|
||||||
|
|||||||
6
nodes.py
6
nodes.py
@ -1724,6 +1724,8 @@ class LoadImage:
|
|||||||
output_masks = []
|
output_masks = []
|
||||||
w, h = None, None
|
w, h = None, None
|
||||||
|
|
||||||
|
dtype = comfy.model_management.intermediate_dtype()
|
||||||
|
|
||||||
for i in ImageSequence.Iterator(img):
|
for i in ImageSequence.Iterator(img):
|
||||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||||
|
|
||||||
@ -1748,8 +1750,8 @@ class LoadImage:
|
|||||||
mask = 1. - torch.from_numpy(mask)
|
mask = 1. - torch.from_numpy(mask)
|
||||||
else:
|
else:
|
||||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||||
output_images.append(image)
|
output_images.append(image.to(dtype=dtype))
|
||||||
output_masks.append(mask.unsqueeze(0))
|
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
|
||||||
|
|
||||||
if img.format == "MPO":
|
if img.format == "MPO":
|
||||||
break # ignore all frames except the first one for MPO format
|
break # ignore all frames except the first one for MPO format
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
comfyui-frontend-package==1.41.19
|
comfyui-frontend-package==1.41.20
|
||||||
comfyui-workflow-templates==0.9.21
|
comfyui-workflow-templates==0.9.21
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
@ -23,7 +23,7 @@ SQLAlchemy
|
|||||||
filelock
|
filelock
|
||||||
av>=14.2.0
|
av>=14.2.0
|
||||||
comfy-kitchen>=0.2.8
|
comfy-kitchen>=0.2.8
|
||||||
comfy-aimdo>=0.2.10
|
comfy-aimdo>=0.2.12
|
||||||
requests
|
requests
|
||||||
simpleeval>=1.0.0
|
simpleeval>=1.0.0
|
||||||
blake3
|
blake3
|
||||||
|
|||||||
@ -310,7 +310,7 @@ class PromptServer():
|
|||||||
@routes.get("/")
|
@routes.get("/")
|
||||||
async def get_root(request):
|
async def get_root(request):
|
||||||
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
||||||
response.headers['Cache-Control'] = 'no-cache'
|
response.headers['Cache-Control'] = 'no-store, must-revalidate'
|
||||||
response.headers["Pragma"] = "no-cache"
|
response.headers["Pragma"] = "no-cache"
|
||||||
response.headers["Expires"] = "0"
|
response.headers["Expires"] = "0"
|
||||||
return response
|
return response
|
||||||
|
|||||||
@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
|
|||||||
},
|
},
|
||||||
# JavaScript/CSS scenarios
|
# JavaScript/CSS scenarios
|
||||||
{
|
{
|
||||||
"name": "js_no_cache",
|
"name": "js_no_store",
|
||||||
"path": "/script.js",
|
"path": "/script.js",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "css_no_cache",
|
"name": "css_no_store",
|
||||||
"path": "/styles.css",
|
"path": "/styles.css",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "index_json_no_cache",
|
"name": "index_json_no_store",
|
||||||
"path": "/api/index.json",
|
"path": "/api/index.json",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "localized_index_json_no_cache",
|
"name": "localized_index_json_no_store",
|
||||||
"path": "/templates/index.zh.json",
|
"path": "/templates/index.zh.json",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
# Non-matching files
|
# Non-matching files
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user