diff --git a/README.md b/README.md index 56b7966cf..62c4f528c 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a ## Get Started +### Local + #### [Desktop Application](https://www.comfy.org/download) - The easiest way to get started. - Available on Windows & macOS. @@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a #### [Manual Install](#manual-install-windows-linux) Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend). -## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/) -See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/). +### Cloud + +#### [Comfy Cloud](https://www.comfy.org/cloud) +- Our official paid cloud version for those who can't afford local hardware. + +## Examples +See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/). ## Features - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. diff --git a/comfy/cli_args.py b/comfy/cli_args.py index e9832acaf..0a0bf2f30 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") +parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.") + parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") diff --git a/comfy/float.py b/comfy/float.py index 88c47cd80..184b3d6d0 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed= output_block[i:i + slice_size].copy_(block) return output_fp4, to_blocked(output_block, flatten=False) + + +def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0): + def roundup(x_val, multiple): + return ((x_val + multiple - 1) // multiple) * multiple + + if pad_32x: + rows, cols = x.shape + padded_rows = roundup(rows, 32) + padded_cols = roundup(cols, 32) + if padded_rows != rows or padded_cols != cols: + x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows)) + + F8_E4M3_MAX = 448.0 + E8M0_BIAS = 127 + BLOCK_SIZE = 32 + + rows, cols = x.shape + x_blocked = x.reshape(rows, -1, BLOCK_SIZE) + max_abs = torch.amax(torch.abs(x_blocked), dim=-1) + + # E8M0 block scales (power-of-2 exponents) + scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127)) + exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254) + block_scales_e8m0 = exp_biased.to(torch.uint8) + + zero_mask = (max_abs == 0) + block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32) + block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32) + + # Scale per-block then stochastic round + data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols) + output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed) + + block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0) + return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu) diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 5b57dfc5e..9f14f64a5 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d from .pixel_norm import PixelNorm from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings import comfy.ops +import comfy.model_management from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed ops = comfy.ops.disable_weight_init @@ -536,7 +537,7 @@ class Decoder(nn.Module): mark_conv3d_ended(self.conv_out) sample = self.conv_out(sample, causal=self.causal) if sample is not None and sample.shape[2] > 0: - output.append(sample) + output.append(sample.to(comfy.model_management.intermediate_device())) return up_block = self.up_blocks[idx] diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 0b7da2852..563224098 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -1,9 +1,68 @@ import math +import ctypes +import threading +import dataclasses import torch from typing import NamedTuple from comfy.quant_ops import QuantizedTensor + +class TensorFileSlice(NamedTuple): + file_ref: object + thread_id: int + offset: int + size: int + + +def read_tensor_file_slice_into(tensor, destination): + + if isinstance(tensor, QuantizedTensor): + if not isinstance(destination, QuantizedTensor): + return False + if tensor._layout_cls != destination._layout_cls: + return False + + if not read_tensor_file_slice_into(tensor._qdata, destination._qdata): + return False + + dst_orig_dtype = destination._params.orig_dtype + destination._params.copy_from(tensor._params, non_blocking=False) + destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype) + return True + + info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None) + if info is None: + return False + + file_obj = info.file_ref + if (destination.device.type != "cpu" + or file_obj is None + or threading.get_ident() != info.thread_id + or destination.numel() * destination.element_size() < info.size): + return False + + if info.size == 0: + return True + + buf_type = ctypes.c_ubyte * info.size + view = memoryview(buf_type.from_address(destination.data_ptr())) + + try: + file_obj.seek(info.offset) + done = 0 + while done < info.size: + try: + n = file_obj.readinto(view[done:]) + except OSError: + return False + if n <= 0: + return False + done += n + return True + finally: + view.release() + class TensorGeometry(NamedTuple): shape: any dtype: torch.dtype diff --git a/comfy/model_management.py b/comfy/model_management.py index 81c89b180..442d5a40a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -505,6 +505,28 @@ def module_size(module): module_mem += t.nbytes return module_mem +def module_mmap_residency(module, free=False): + mmap_touched_mem = 0 + module_mem = 0 + bounced_mmaps = set() + sd = module.state_dict() + for k in sd: + t = sd[k] + module_mem += t.nbytes + storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage() + if not getattr(storage, "_comfy_tensor_mmap_touched", False): + continue + mmap_touched_mem += t.nbytes + if not free: + continue + storage._comfy_tensor_mmap_touched = False + mmap_obj = storage._comfy_tensor_mmap_refs[0] + if mmap_obj in bounced_mmaps: + continue + mmap_obj.bounce() + bounced_mmaps.add(mmap_obj) + return mmap_touched_mem, module_mem + class LoadedModel: def __init__(self, model): self._set_model(model) @@ -532,6 +554,9 @@ class LoadedModel: def model_memory(self): return self.model.model_size() + def model_mmap_residency(self, free=False): + return self.model.model_mmap_residency(free=free) + def model_loaded_memory(self): return self.model.loaded_size() @@ -633,7 +658,7 @@ def extra_reserved_memory(): def minimum_inference_memory(): return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() -def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): +def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0): cleanup_models_gc() unloaded_model = [] can_unload = [] @@ -646,13 +671,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) shift_model.currently_used = False - for x in sorted(can_unload): + can_unload_sorted = sorted(can_unload) + for x in can_unload_sorted: i = x[-1] memory_to_free = 1e32 - ram_to_free = 1e32 + pins_to_free = 1e32 if not DISABLE_SMART_MEMORY: memory_to_free = memory_required - get_free_memory(device) - ram_to_free = ram_required - get_free_ram() + pins_to_free = pins_required - get_free_ram() if current_loaded_models[i].model.is_dynamic() and for_dynamic: #don't actually unload dynamic models for the sake of other dynamic models #as that works on-demand. @@ -661,9 +687,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_ if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") unloaded_model.append(i) - if ram_to_free > 0: + if pins_to_free > 0: + logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}") + current_loaded_models[i].model.partially_unload_ram(pins_to_free) + + for x in can_unload_sorted: + i = x[-1] + ram_to_free = ram_required - psutil.virtual_memory().available + if ram_to_free <= 0 and i not in unloaded_model: + continue + resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True) + if resident_memory > 0: logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") - current_loaded_models[i].model.partially_unload_ram(ram_to_free) for i in sorted(unloaded_model, reverse=True): unloaded_models.append(current_loaded_models.pop(i)) @@ -729,17 +764,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu total_memory_required = {} + total_pins_required = {} total_ram_required = {} for loaded_model in models_to_load: - total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) - #x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we - #want to do. - #FIXME: This should subtract off the to_load current pin consumption. - total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 + device = loaded_model.device + total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device) + resident_memory, model_memory = loaded_model.model_mmap_residency() + pinned_memory = loaded_model.model.pinned_memory_size() + #FIXME: This can over-free the pins as it budgets to pin the entire model. We should + #make this JIT to keep as much pinned as possible. + pins_required = model_memory - pinned_memory + ram_required = model_memory - resident_memory + total_pins_required[device] = total_pins_required.get(device, 0) + pins_required + total_ram_required[device] = total_ram_required.get(device, 0) + ram_required for device in total_memory_required: if device != torch.device("cpu"): - free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) + free_memory(total_memory_required[device] * 1.1 + extra_mem, + device, + for_dynamic=free_for_dynamic, + pins_required=total_pins_required[device], + ram_required=total_ram_required[device]) for device in total_memory_required: if device != torch.device("cpu"): @@ -1005,6 +1050,12 @@ def intermediate_device(): else: return torch.device("cpu") +def intermediate_dtype(): + if args.fp16_intermediates: + return torch.float16 + else: + return torch.float32 + def vae_device(): if args.cpu_vae: return torch.device("cpu") @@ -1225,6 +1276,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None): dest_view = dest_views.pop(0) if tensor is None: continue + if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view): + continue + storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage() + if hasattr(storage, "_comfy_tensor_mmap_touched"): + storage._comfy_tensor_mmap_touched = True dest_view.copy_(tensor, non_blocking=non_blocking) @@ -1662,6 +1718,19 @@ def supports_nvfp4_compute(device=None): return True +def supports_mxfp8_compute(device=None): + if not is_nvidia(): + return False + + if torch_version_numeric < (2, 10): + return False + + props = torch.cuda.get_device_properties(device) + if props.major < 10: + return False + + return True + def extended_fp16_support(): # TODO: check why some models work with fp16 on newer torch versions but not on older if torch_version_numeric < (2, 7): diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index bc3a8f446..c26d37db2 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -297,6 +297,9 @@ class ModelPatcher: self.size = comfy.model_management.module_size(self.model) return self.size + def model_mmap_residency(self, free=False): + return comfy.model_management.module_mmap_residency(self.model, free=free) + def get_ram_usage(self): return self.model_size() @@ -1063,6 +1066,10 @@ class ModelPatcher: return self.model.model_loaded_weight_memory - current_used + def pinned_memory_size(self): + # Pinned memory pressure tracking is only implemented for DynamicVram loading + return 0 + def partially_unload_ram(self, ram_to_unload): pass @@ -1653,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher): return freed + def pinned_memory_size(self): + total = 0 + loading = self._load_list(for_dynamic=True) + for x in loading: + _, _, _, _, m, _ = x + pin = comfy.pinned_memory.get_pin(m) + if pin is not None: + total += pin.numel() * pin.element_size() + return total + def partially_unload_ram(self, ram_to_unload): loading = self._load_list(for_dynamic=True, default_device=self.offload_device) for x in loading: diff --git a/comfy/ops.py b/comfy/ops.py index 87b36b5c5..59c0df87d 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -306,6 +306,33 @@ class CastWeightBiasOp: bias_function = [] class disable_weight_init: + @staticmethod + def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata, + missing_keys, unexpected_keys, weight_shape, + bias_shape=None): + assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) + prefix_len = len(prefix) + for k, v in state_dict.items(): + key = k[prefix_len:] + if key == "weight": + if not assign_to_params_buffers: + v = v.clone() + module.weight = torch.nn.Parameter(v, requires_grad=False) + elif bias_shape is not None and key == "bias" and v is not None: + if not assign_to_params_buffers: + v = v.clone() + module.bias = torch.nn.Parameter(v, requires_grad=False) + else: + unexpected_keys.append(k) + + if module.weight is None: + module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False) + missing_keys.append(prefix + "weight") + + if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False): + module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False) + missing_keys.append(prefix + "bias") + class Linear(torch.nn.Linear, CastWeightBiasOp): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): @@ -333,29 +360,16 @@ class disable_weight_init: if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) - assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) - prefix_len = len(prefix) - for k,v in state_dict.items(): - if k[prefix_len:] == "weight": - if not assign_to_params_buffers: - v = v.clone() - self.weight = torch.nn.Parameter(v, requires_grad=False) - elif k[prefix_len:] == "bias" and v is not None: - if not assign_to_params_buffers: - v = v.clone() - self.bias = torch.nn.Parameter(v, requires_grad=False) - else: - unexpected_keys.append(k) - - #Reconcile default construction of the weight if its missing. - if self.weight is None: - v = torch.zeros(self.in_features, self.out_features) - self.weight = torch.nn.Parameter(v, requires_grad=False) - missing_keys.append(prefix+"weight") - if self.bias is None and self.comfy_need_lazy_init_bias: - v = torch.zeros(self.out_features,) - self.bias = torch.nn.Parameter(v, requires_grad=False) - missing_keys.append(prefix+"bias") + disable_weight_init._lazy_load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + missing_keys, + unexpected_keys, + weight_shape=(self.in_features, self.out_features), + bias_shape=(self.out_features,), + ) def reset_parameters(self): @@ -547,6 +561,48 @@ class disable_weight_init: return super().forward(*args, **kwargs) class Embedding(torch.nn.Embedding, CastWeightBiasOp): + def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, + norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, + _freeze=False, device=None, dtype=None): + if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm, + norm_type, scale_grad_by_freq, sparse, _weight, + _freeze, device, dtype) + return + + torch.nn.Module.__init__(self) + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + self.max_norm = max_norm + self.norm_type = norm_type + self.scale_grad_by_freq = scale_grad_by_freq + self.sparse = sparse + # Keep shape/dtype visible for module introspection without reserving storage. + embedding_dtype = dtype if dtype is not None else torch.get_default_dtype() + self.weight = torch.nn.Parameter( + torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype), + requires_grad=False, + ) + self.bias = None + self.weight_comfy_model_dtype = dtype + + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + + if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs) + disable_weight_init._lazy_load_from_state_dict( + self, + state_dict, + prefix, + local_metadata, + missing_keys, + unexpected_keys, + weight_shape=(self.num_embeddings, self.embedding_dim), + ) + def reset_parameters(self): self.bias = None return None @@ -801,6 +857,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec orig_shape=(self.out_features, self.in_features), ) + elif self.quant_format == "mxfp8": + # MXFP8: E8M0 block scales stored as uint8 in safetensors + block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys, + dtype=torch.uint8) + + if block_scale is None: + raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}") + + block_scale = block_scale.view(torch.float8_e8m0fnu) + + params = layout_cls.Params( + scale=block_scale, + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.out_features, self.in_features), + ) + elif self.quant_format == "nvfp4": # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) @@ -950,12 +1022,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) + mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device) if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: logging.info("Using mixed precision operations") disabled = set() if not nvfp4_compute: disabled.add("nvfp4") + if not mxfp8_compute: + disabled.add("mxfp8") if not fp8_compute: disabled.add("float8_e4m3fn") disabled.add("float8_e5m2") diff --git a/comfy/pinned_memory.py b/comfy/pinned_memory.py index 8acc327a7..f6fb806c4 100644 --- a/comfy/pinned_memory.py +++ b/comfy/pinned_memory.py @@ -1,6 +1,7 @@ -import torch import comfy.model_management import comfy.memory_management +import comfy_aimdo.host_buffer +import comfy_aimdo.torch from comfy.cli_args import args @@ -12,18 +13,31 @@ def pin_memory(module): return #FIXME: This is a RAM cache trigger event size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) - pin = torch.empty((size,), dtype=torch.uint8) - if comfy.model_management.pin_memory(pin): - module._pin = pin - else: + + if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY: module.pin_failed = True return False + + try: + hostbuf = comfy_aimdo.host_buffer.HostBuffer(size) + except RuntimeError: + module.pin_failed = True + return False + + module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf) + module._pin_hostbuf = hostbuf + comfy.model_management.TOTAL_PINNED_MEMORY += size return True def unpin_memory(module): if get_pin(module) is None: return 0 size = module._pin.numel() * module._pin.element_size() - comfy.model_management.unpin_memory(module._pin) + + comfy.model_management.TOTAL_PINNED_MEMORY -= size + if comfy.model_management.TOTAL_PINNED_MEMORY < 0: + comfy.model_management.TOTAL_PINNED_MEMORY = 0 + del module._pin + del module._pin_hostbuf return size diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index 15a4f457b..42ee08fb2 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -43,6 +43,18 @@ except ImportError as e: def get_layout_class(name): return None +_CK_MXFP8_AVAILABLE = False +if _CK_AVAILABLE: + try: + from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout + _CK_MXFP8_AVAILABLE = True + except ImportError: + logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.") + +if not _CK_MXFP8_AVAILABLE: + class _CKMxfp8Layout: + pass + import comfy.float # ============================================================================== @@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout): return qdata, params +class TensorCoreMXFP8Layout(_CKMxfp8Layout): + @classmethod + def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): + if tensor.dim() != 2: + raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D") + + orig_dtype = tensor.dtype + orig_shape = tuple(tensor.shape) + + padded_shape = cls.get_padded_shape(orig_shape) + needs_padding = padded_shape != orig_shape + + if stochastic_rounding > 0: + qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding) + else: + qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding) + + params = cls.Params( + scale=block_scale, + orig_dtype=orig_dtype, + orig_shape=orig_shape, + ) + return qdata, params + + class TensorCoreNVFP4Layout(_CKNvfp4Layout): @classmethod def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): @@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) +if _CK_MXFP8_AVAILABLE: + register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout) QUANT_ALGOS = { "float8_e4m3fn": { @@ -157,6 +196,14 @@ QUANT_ALGOS = { }, } +if _CK_MXFP8_AVAILABLE: + QUANT_ALGOS["mxfp8"] = { + "storage_t": torch.float8_e4m3fn, + "parameters": {"weight_scale", "input_scale"}, + "comfy_tensor_layout": "TensorCoreMXFP8Layout", + "group_size": 32, + } + # ============================================================================== # Re-exports for backward compatibility diff --git a/comfy/sd.py b/comfy/sd.py index 0812d8ea4..54fbe504c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -872,13 +872,16 @@ class VAE: pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value) return pixels + def vae_output_dtype(self): + return model_management.intermediate_dtype() + def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) output = self.process_output( (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + @@ -888,16 +891,16 @@ class VAE: def decode_tiled_1d(self, samples, tile_x=256, overlap=32): if samples.ndim == 3: - decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) else: og_shape = samples.shape samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) - decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float() + decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): - decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() + decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): @@ -906,7 +909,7 @@ class VAE: steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) pbar = comfy.utils.ProgressBar(steps) - encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) @@ -915,7 +918,7 @@ class VAE: def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): if self.latent_dim == 1: - encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) out_channels = self.latent_channels upscale_amount = 1 / self.downscale_ratio else: @@ -924,7 +927,7 @@ class VAE: tile_x = tile_x // extra_channel_size overlap = overlap // extra_channel_size upscale_amount = 1 / self.downscale_ratio - encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype()) out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) if self.latent_dim == 1: @@ -933,7 +936,7 @@ class VAE: return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1) def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): - encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() + encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype()) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) def decode(self, samples_in, vae_options={}): @@ -951,9 +954,9 @@ class VAE: for x in range(0, samples_in.shape[0], batch_number): samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float()) + out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype())) if pixel_samples is None: - pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) + pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) pixel_samples[x:x+batch_number] = out except Exception as e: model_management.raise_non_oom(e) @@ -1026,9 +1029,9 @@ class VAE: samples = None for x in range(0, pixel_samples.shape[0], batch_number): pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) - out = self.first_stage_model.encode(pixels_in).to(self.output_device).float() + out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: - samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) + samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) samples[x:x + batch_number] = out except Exception as e: diff --git a/comfy/utils.py b/comfy/utils.py index 6e1d14419..9931fe3b4 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -20,6 +20,8 @@ import torch import math import struct +import ctypes +import os import comfy.memory_management import safetensors.torch import numpy as np @@ -32,7 +34,7 @@ from einops import rearrange from comfy.cli_args import args import json import time -import mmap +import threading import warnings MMAP_TORCH_FILES = args.mmap_torch_files @@ -81,14 +83,17 @@ _TYPES = { } def load_safetensors(ckpt): - f = open(ckpt, "rb") - mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) - mv = memoryview(mapping) + import comfy_aimdo.model_mmap - header_size = struct.unpack("=14.2.0 comfy-kitchen>=0.2.8 -comfy-aimdo>=0.2.10 +comfy-aimdo>=0.2.12 requests simpleeval>=1.0.0 blake3 diff --git a/server.py b/server.py index 76904ebc9..85a8964be 100644 --- a/server.py +++ b/server.py @@ -310,7 +310,7 @@ class PromptServer(): @routes.get("/") async def get_root(request): response = web.FileResponse(os.path.join(self.web_root, "index.html")) - response.headers['Cache-Control'] = 'no-cache' + response.headers['Cache-Control'] = 'no-store, must-revalidate' response.headers["Pragma"] = "no-cache" response.headers["Expires"] = "0" return response diff --git a/tests-unit/server_test/test_cache_control.py b/tests-unit/server_test/test_cache_control.py index fa68d9408..1d0366387 100644 --- a/tests-unit/server_test/test_cache_control.py +++ b/tests-unit/server_test/test_cache_control.py @@ -28,31 +28,31 @@ CACHE_SCENARIOS = [ }, # JavaScript/CSS scenarios { - "name": "js_no_cache", + "name": "js_no_store", "path": "/script.js", "status": 200, - "expected_cache": "no-cache", + "expected_cache": "no-store", "should_have_header": True, }, { - "name": "css_no_cache", + "name": "css_no_store", "path": "/styles.css", "status": 200, - "expected_cache": "no-cache", + "expected_cache": "no-store", "should_have_header": True, }, { - "name": "index_json_no_cache", + "name": "index_json_no_store", "path": "/api/index.json", "status": 200, - "expected_cache": "no-cache", + "expected_cache": "no-store", "should_have_header": True, }, { - "name": "localized_index_json_no_cache", + "name": "localized_index_json_no_store", "path": "/templates/index.zh.json", "status": 200, - "expected_cache": "no-cache", + "expected_cache": "no-store", "should_have_header": True, }, # Non-matching files