diff --git a/README.md b/README.md index 91fb510e1..ed857df9f 100644 --- a/README.md +++ b/README.md @@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/) - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/) - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/) + - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5) - Audio Models - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/) - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/) diff --git a/comfy/context_windows.py b/comfy/context_windows.py index 041f380f9..5c412d1c2 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -51,26 +51,36 @@ class ContextHandlerABC(ABC): class IndexListContextWindow(ContextWindowABC): - def __init__(self, index_list: list[int], dim: int=0): + def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0): self.index_list = index_list self.context_length = len(index_list) self.dim = dim + self.total_frames = total_frames + self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames) - def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor: + def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor: if dim is None: dim = self.dim if dim == 0 and full.shape[dim] == 1: return full - idx = [slice(None)] * dim + [self.index_list] - return full[idx].to(device) + idx = tuple([slice(None)] * dim + [self.index_list]) + window = full[idx] + if retain_index_list: + idx = tuple([slice(None)] * dim + [retain_index_list]) + window[idx] = full[idx] + return window.to(device) def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor: if dim is None: dim = self.dim - idx = [slice(None)] * dim + [self.index_list] + idx = tuple([slice(None)] * dim + [self.index_list]) full[idx] += to_add return full + def get_region_index(self, num_regions: int) -> int: + region_idx = int(self.center_ratio * num_regions) + return min(max(region_idx, 0), num_regions - 1) + class IndexListCallbacks: EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows" @@ -94,7 +104,8 @@ class ContextFuseMethod: ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window']) class IndexListContextHandler(ContextHandlerABC): - def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0): + def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, + closed_loop: bool=False, dim:int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False): self.context_schedule = context_schedule self.fuse_method = fuse_method self.context_length = context_length @@ -103,13 +114,18 @@ class IndexListContextHandler(ContextHandlerABC): self.closed_loop = closed_loop self.dim = dim self._step = 0 + self.freenoise = freenoise + self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else [] + self.split_conds_to_windows = split_conds_to_windows self.callbacks = {} def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool: # for now, assume first dim is batch - should have stored on BaseModel in actual implementation if x_in.size(self.dim) > self.context_length: - logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.") + logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.") + if self.cond_retain_index_list: + logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}") return True return False @@ -123,6 +139,11 @@ class IndexListContextHandler(ContextHandlerABC): return None # reuse or resize cond items to match context requirements resized_cond = [] + # if multiple conds, split based on primary region + if self.split_conds_to_windows and len(cond_in) > 1: + region = window.get_region_index(len(cond_in)) + logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}") + cond_in = [cond_in[region]] # cond object is a list containing a dict - outer list is irrelevant, so just loop through it for actual_cond in cond_in: resized_actual_cond = actual_cond.copy() @@ -146,12 +167,19 @@ class IndexListContextHandler(ContextHandlerABC): # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor) for cond_key, cond_value in new_cond_item.items(): if isinstance(cond_value, torch.Tensor): - if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim): + if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): new_cond_item[cond_key] = window.get_tensor(cond_value, device) + # Handle audio_embed (temporal dim is 1) + elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): + audio_cond = cond_value.cond + if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1)) # if has cond that is a Tensor, check if needs to be subset elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor): - if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim): - new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device)) + if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \ + (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)): + new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list)) elif cond_key == "num_video_frames": # for SVD new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond) new_cond_item[cond_key].cond = window.context_length @@ -164,7 +192,7 @@ class IndexListContextHandler(ContextHandlerABC): return resized_cond def set_step(self, timestep: torch.Tensor, model_options: dict[str]): - mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001) + mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001) matches = torch.nonzero(mask) if torch.numel(matches) == 0: raise Exception("No sample_sigmas matched current timestep; something went wrong.") @@ -173,7 +201,7 @@ class IndexListContextHandler(ContextHandlerABC): def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]: full_length = x_in.size(self.dim) # TODO: choose dim based on model context_windows = self.context_schedule.func(full_length, self, model_options) - context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows] + context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows] return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): @@ -250,8 +278,8 @@ class IndexListContextHandler(ContextHandlerABC): prev_weight = (bias_total / (bias_total + bias)) new_weight = (bias / (bias_total + bias)) # account for dims of tensors - idx_window = [slice(None)] * self.dim + [idx] - pos_window = [slice(None)] * self.dim + [pos] + idx_window = tuple([slice(None)] * self.dim + [idx]) + pos_window = tuple([slice(None)] * self.dim + [pos]) # apply new values conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight biases_final[i][idx] = bias_total + bias @@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher): ) +def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs): + model_options = extra_args.get("model_options", None) + if model_options is None: + raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.") + handler: IndexListContextHandler = model_options.get("context_handler", None) + if handler is None: + raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.") + if not handler.freenoise: + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"]) + + return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs) + + +def create_sampler_sample_wrapper(model: ModelPatcher): + model.add_wrapper_with_key( + comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE, + "ContextWindows_sampler_sample", + _sampler_sample_wrapper + ) + + def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor: total_dims = len(x_in.shape) weights_tensor = torch.Tensor(weights).to(device=device) @@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int): for i in range(len(window)): # 2) add end_delta to each val to slide windows to end window[i] = window[i] + end_delta + + +# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465 +def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int): + logging.info("Context windows: Applying FreeNoise") + generator = torch.Generator(device='cpu').manual_seed(seed) + latent_video_length = noise.shape[dim] + delta = context_length - context_overlap + + for start_idx in range(0, latent_video_length - context_length, delta): + place_idx = start_idx + context_length + + actual_delta = min(delta, latent_video_length - place_idx) + if actual_delta <= 0: + break + + list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx + + source_slice = [slice(None)] * noise.ndim + source_slice[dim] = list_idx + target_slice = [slice(None)] * noise.ndim + target_slice[dim] = slice(place_idx, place_idx + actual_delta) + + noise[tuple(target_slice)] = noise[tuple(source_slice)] + + return noise diff --git a/comfy/ldm/lumina/model.py b/comfy/ldm/lumina/model.py index f1c1a0ec3..6c24fed9b 100644 --- a/comfy/ldm/lumina/model.py +++ b/comfy/ldm/lumina/model.py @@ -586,7 +586,6 @@ class NextDiT(nn.Module): cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute patches = transformer_options.get("patches", {}) - transformer_options = kwargs.get("transformer_options", {}) x_is_tensor = isinstance(x, torch.Tensor) img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options) freqs_cis = freqs_cis.to(img.device) diff --git a/comfy/model_base.py b/comfy/model_base.py index 9b76c285e..3cedd4f31 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -134,7 +134,7 @@ class BaseModel(torch.nn.Module): if not unet_config.get("disable_unet_model_creation", False): if model_config.custom_operations is None: fp8 = model_config.optimizations.get("fp8", False) - operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config) + operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config) else: operations = model_config.custom_operations self.diffusion_model = unet_model(**unet_config, device=device, operations=operations) @@ -329,18 +329,6 @@ class BaseModel(torch.nn.Module): extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict)) unet_state_dict = self.diffusion_model.state_dict() - - if self.model_config.scaled_fp8 is not None: - unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8) - - # Save mixed precision metadata - if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config: - metadata = { - "format_version": "1.0", - "layers": self.model_config.layer_quant_config - } - unet_state_dict["_quantization_metadata"] = metadata - unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict) if self.model_type == ModelType.V_PREDICTION: diff --git a/comfy/model_detection.py b/comfy/model_detection.py index 7d0517e61..fd1907627 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -6,20 +6,6 @@ import math import logging import torch - -def detect_layer_quantization(metadata): - quant_key = "_quantization_metadata" - if metadata is not None and quant_key in metadata: - quant_metadata = metadata.pop(quant_key) - quant_metadata = json.loads(quant_metadata) - if isinstance(quant_metadata, dict) and "layers" in quant_metadata: - logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})") - return quant_metadata["layers"] - else: - raise ValueError("Invalid quantization metadata format") - return None - - def count_blocks(state_dict_keys, prefix_string): count = 0 while True: @@ -767,22 +753,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal if model_config is None and use_base_if_no_match: model_config = comfy.supported_models_base.BASE(unet_config) - scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix) - if scaled_fp8_key in state_dict: - scaled_fp8_weight = state_dict.pop(scaled_fp8_key) - model_config.scaled_fp8 = scaled_fp8_weight.dtype - if model_config.scaled_fp8 == torch.float32: - model_config.scaled_fp8 = torch.float8_e4m3fn - if scaled_fp8_weight.nelement() == 2: - model_config.optimizations["fp8"] = False - else: - model_config.optimizations["fp8"] = True - # Detect per-layer quantization (mixed precision) - layer_quant_config = detect_layer_quantization(metadata) - if layer_quant_config: - model_config.layer_quant_config = layer_quant_config - logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized") + quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix) + if quant_config: + model_config.quant_config = quant_config + logging.info("Detected mixed precision quantization") return model_config diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index 3dcac3eef..215784874 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -126,27 +126,11 @@ class LowVramPatch: def __init__(self, key, patches, convert_func=None, set_func=None): self.key = key self.patches = patches - self.convert_func = convert_func + self.convert_func = convert_func # TODO: remove self.set_func = set_func def __call__(self, weight): - intermediate_dtype = weight.dtype - if self.convert_func is not None: - weight = self.convert_func(weight, inplace=False) - - if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops - intermediate_dtype = torch.float32 - out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is None: - return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key)) - else: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True) - - out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype) - if self.set_func is not None: - return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype) - else: - return out + return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype) #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3 diff --git a/comfy/ops.py b/comfy/ops.py index eae434e68..35237c9f7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -23,6 +23,7 @@ from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib +import json def run_every_op(): if torch.compiler.is_compiling(): @@ -422,22 +423,12 @@ def fp8_linear(self, input): if input.ndim == 3 or input.ndim == 2: w, bias, offload_stream = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype, offloadable=True) + scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - scale_weight = self.scale_weight - scale_input = self.scale_input - if scale_weight is None: - scale_weight = torch.ones((), device=input.device, dtype=torch.float32) - else: - scale_weight = scale_weight.to(input.device) - - if scale_input is None: - scale_input = torch.ones((), device=input.device, dtype=torch.float32) - input = torch.clamp(input, min=-448, max=448, out=input) - layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} - quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) - else: - scale_input = scale_input.to(input.device) - quantized_input = QuantizedTensor.from_float(input, "TensorCoreFP8Layout", scale=scale_input, dtype=dtype) + scale_input = torch.ones((), device=input.device, dtype=torch.float32) + input = torch.clamp(input, min=-448, max=448, out=input) + layout_params_weight = {'scale': scale_input, 'orig_dtype': input_dtype} + quantized_input = QuantizedTensor(input.to(dtype).contiguous(), "TensorCoreFP8Layout", layout_params_weight) # Wrap weight in QuantizedTensor - this enables unified dispatch # Call F.linear - __torch_dispatch__ routes to fp8_linear handler in quant_ops.py! @@ -458,7 +449,7 @@ class fp8_ops(manual_cast): return None def forward_comfy_cast_weights(self, input): - if not self.training: + if len(self.weight_function) == 0 and len(self.bias_function) == 0: try: out = fp8_linear(self, input) if out is not None: @@ -471,59 +462,6 @@ class fp8_ops(manual_cast): uncast_bias_weight(self, weight, bias, offload_stream) return x -def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None): - logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input)) - class scaled_fp8_op(manual_cast): - class Linear(manual_cast.Linear): - def __init__(self, *args, **kwargs): - if override_dtype is not None: - kwargs['dtype'] = override_dtype - super().__init__(*args, **kwargs) - - def reset_parameters(self): - if not hasattr(self, 'scale_weight'): - self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - - if not scale_input: - self.scale_input = None - - if not hasattr(self, 'scale_input'): - self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False) - return None - - def forward_comfy_cast_weights(self, input): - if fp8_matrix_mult: - out = fp8_linear(self, input) - if out is not None: - return out - - weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True) - - if weight.numel() < input.numel(): #TODO: optimize - x = torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias) - else: - x = torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias) - uncast_bias_weight(self, weight, bias, offload_stream) - return x - - def convert_weight(self, weight, inplace=False, **kwargs): - if inplace: - weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype) - return weight - else: - return weight.to(dtype=torch.float32) * self.scale_weight.to(device=weight.device, dtype=torch.float32) - - def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): - weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed) - if return_weight: - return weight - if inplace_update: - self.weight.data.copy_(weight) - else: - self.weight = torch.nn.Parameter(weight, requires_grad=False) - - return scaled_fp8_op - CUBLAS_IS_AVAILABLE = False try: from cublas_ops import CublasLinear @@ -550,9 +488,9 @@ if CUBLAS_IS_AVAILABLE: from .quant_ops import QuantizedTensor, QUANT_ALGOS -def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): +def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False): class MixedPrecisionOps(manual_cast): - _layer_quant_config = layer_quant_config + _quant_config = quant_config _compute_dtype = compute_dtype _full_precision_mm = full_precision_mm @@ -595,27 +533,38 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful manually_loaded_keys = [weight_key] - if layer_name not in MixedPrecisionOps._layer_quant_config: + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + if layer_conf is None: self.weight = torch.nn.Parameter(weight.to(device=device, dtype=MixedPrecisionOps._compute_dtype), requires_grad=False) else: - quant_format = MixedPrecisionOps._layer_quant_config[layer_name].get("format", None) - if quant_format is None: + self.quant_format = layer_conf.get("format", None) + if not self._full_precision_mm: + self._full_precision_mm = layer_conf.get("full_precision_matrix_mult", False) + + if self.quant_format is None: raise ValueError(f"Unknown quantization format for layer {layer_name}") - qconfig = QUANT_ALGOS[quant_format] + qconfig = QUANT_ALGOS[self.quant_format] self.layout_type = qconfig["comfy_tensor_layout"] weight_scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(weight_scale_key, None) + if scale is not None: + scale = scale.to(device) layout_params = { - 'scale': state_dict.pop(weight_scale_key, None), + 'scale': scale, 'orig_dtype': MixedPrecisionOps._compute_dtype, 'block_size': qconfig.get("group_size", None), } - if layout_params['scale'] is not None: + + if scale is not None: manually_loaded_keys.append(weight_scale_key) self.weight = torch.nn.Parameter( - QuantizedTensor(weight.to(device=device), self.layout_type, layout_params), + QuantizedTensor(weight.to(device=device, dtype=qconfig.get("storage_t", None)), self.layout_type, layout_params), requires_grad=False ) @@ -624,7 +573,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful _v = state_dict.pop(param_key, None) if _v is None: continue - setattr(self, param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) + self.register_parameter(param_name, torch.nn.Parameter(_v.to(device=device), requires_grad=False)) manually_loaded_keys.append(param_key) super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -633,6 +582,16 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful if key in missing_keys: missing_keys.remove(key) + def state_dict(self, *args, destination=None, prefix="", **kwargs): + sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs) + if isinstance(self.weight, QuantizedTensor): + sd["{}weight_scale".format(prefix)] = self.weight._layout_params['scale'] + quant_conf = {"format": self.quant_format} + if self._full_precision_mm: + quant_conf["full_precision_matrix_mult"] = True + sd["{}comfy_quant".format(prefix)] = torch.frombuffer(json.dumps(quant_conf).encode('utf-8'), dtype=torch.uint8) + return sd + def _forward(self, input, weight, bias): return torch.nn.functional.linear(input, weight, bias) @@ -648,9 +607,8 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful if self._full_precision_mm or self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0: return self.forward_comfy_cast_weights(input, *args, **kwargs) if (getattr(self, 'layout_type', None) is not None and - getattr(self, 'input_scale', None) is not None and not isinstance(input, QuantizedTensor)): - input = QuantizedTensor.from_float(input, self.layout_type, scale=self.input_scale, dtype=self.weight.dtype) + input = QuantizedTensor.from_float(input, self.layout_type, scale=getattr(self, 'input_scale', None), dtype=self.weight.dtype) return self._forward(input, self.weight, self.bias) def convert_weight(self, weight, inplace=False, **kwargs): @@ -661,7 +619,7 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful def set_weight(self, weight, inplace_update=False, seed=None, return_weight=False, **kwargs): if getattr(self, 'layout_type', None) is not None: - weight = QuantizedTensor.from_float(weight, self.layout_type, scale=None, dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) + weight = QuantizedTensor.from_float(weight, self.layout_type, scale="recalculate", dtype=self.weight.dtype, stochastic_rounding=seed, inplace_ops=True) else: weight = weight.to(self.weight.dtype) if return_weight: @@ -670,17 +628,28 @@ def mixed_precision_ops(layer_quant_config={}, compute_dtype=torch.bfloat16, ful assert inplace_update is False # TODO: eventually remove the inplace_update stuff self.weight = torch.nn.Parameter(weight, requires_grad=False) + def _apply(self, fn, recurse=True): # This is to get torch.compile + moving weights to another device working + if recurse: + for module in self.children(): + module._apply(fn) + + for key, param in self._parameters.items(): + if param is None: + continue + self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + for key, buf in self._buffers.items(): + if buf is not None: + self._buffers[key] = fn(buf) + return self + return MixedPrecisionOps -def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, model_config=None): +def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular - if model_config and hasattr(model_config, 'layer_quant_config') and model_config.layer_quant_config: - logging.info(f"Using mixed precision operations: {len(model_config.layer_quant_config)} quantized layers") - return mixed_precision_ops(model_config.layer_quant_config, compute_dtype, full_precision_mm=not fp8_compute) - - if scaled_fp8 is not None: - return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8) + if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: + logging.info("Using mixed precision operations") + return mixed_precision_ops(model_config.quant_config, compute_dtype, full_precision_mm=not fp8_compute) if ( fp8_compute and diff --git a/comfy/quant_ops.py b/comfy/quant_ops.py index bb1fb860c..571d3f760 100644 --- a/comfy/quant_ops.py +++ b/comfy/quant_ops.py @@ -238,6 +238,9 @@ class QuantizedTensor(torch.Tensor): def is_contiguous(self, *arg, **kwargs): return self._qdata.is_contiguous(*arg, **kwargs) + def storage(self): + return self._qdata.storage() + # ============================================================================== # Generic Utilities (Layout-Agnostic Operations) # ============================================================================== @@ -249,12 +252,6 @@ def _create_transformed_qtensor(qt, transform_fn): def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout=None, op_name="to"): - if target_dtype is not None and target_dtype != qt.dtype: - logging.warning( - f"QuantizedTensor: dtype conversion requested to {target_dtype}, " - f"but not supported for quantized tensors. Ignoring dtype." - ) - if target_layout is not None and target_layout != torch.strided: logging.warning( f"QuantizedTensor: layout change requested to {target_layout}, " @@ -274,6 +271,8 @@ def _handle_device_transfer(qt, target_device, target_dtype=None, target_layout= logging.debug(f"QuantizedTensor.{op_name}: Moving from {current_device} to {target_device}") new_q_data = qt._qdata.to(device=target_device) new_params = _move_layout_params_to_device(qt._layout_params, target_device) + if target_dtype is not None: + new_params["orig_dtype"] = target_dtype new_qt = QuantizedTensor(new_q_data, qt._layout_type, new_params) logging.debug(f"QuantizedTensor.{op_name}: Created new tensor on {target_device}") return new_qt @@ -339,7 +338,9 @@ def generic_copy_(func, args, kwargs): # Copy from another quantized tensor qt_dest._qdata.copy_(src._qdata, non_blocking=non_blocking) qt_dest._layout_type = src._layout_type + orig_dtype = qt_dest._layout_params["orig_dtype"] _copy_layout_params_inplace(src._layout_params, qt_dest._layout_params, non_blocking=non_blocking) + qt_dest._layout_params["orig_dtype"] = orig_dtype else: # Copy from regular tensor - just copy raw data qt_dest._qdata.copy_(src) @@ -397,17 +398,20 @@ class TensorCoreFP8Layout(QuantizedLayout): def quantize(cls, tensor, scale=None, dtype=torch.float8_e4m3fn, stochastic_rounding=0, inplace_ops=False): orig_dtype = tensor.dtype - if scale is None: + if isinstance(scale, str) and scale == "recalculate": scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max - if not isinstance(scale, torch.Tensor): - scale = torch.tensor(scale) - scale = scale.to(device=tensor.device, dtype=torch.float32) + if scale is not None: + if not isinstance(scale, torch.Tensor): + scale = torch.tensor(scale) + scale = scale.to(device=tensor.device, dtype=torch.float32) - if inplace_ops: - tensor *= (1.0 / scale).to(tensor.dtype) + if inplace_ops: + tensor *= (1.0 / scale).to(tensor.dtype) + else: + tensor = tensor * (1.0 / scale).to(tensor.dtype) else: - tensor = tensor * (1.0 / scale).to(tensor.dtype) + scale = torch.ones((), device=tensor.device, dtype=torch.float32) if stochastic_rounding > 0: tensor = comfy.float.stochastic_rounding(tensor, dtype=dtype, seed=stochastic_rounding) diff --git a/comfy/sd.py b/comfy/sd.py index 03bdb33d5..c350322f8 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -98,7 +98,7 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip): class CLIP: - def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, model_options={}): + def __init__(self, target=None, embedding_directory=None, no_init=False, tokenizer_data={}, parameters=0, state_dict=[], model_options={}): if no_init: return params = target.params.copy() @@ -129,6 +129,27 @@ class CLIP: self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram self.patcher.is_clip = True self.apply_hooks_to_conds = None + if len(state_dict) > 0: + if isinstance(state_dict, list): + for c in state_dict: + m, u = self.load_sd(c) + if len(m) > 0: + logging.warning("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected: {}".format(u)) + else: + m, u = self.load_sd(state_dict, full_model=True) + if len(m) > 0: + m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) + if len(m_filter) > 0: + logging.warning("clip missing: {}".format(m)) + else: + logging.debug("clip missing: {}".format(m)) + + if len(u) > 0: + logging.debug("clip unexpected {}:".format(u)) + if params['device'] == load_device: model_management.load_models_gpu([self.patcher], force_full_load=True) self.layer_idx = None @@ -968,10 +989,8 @@ def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DI clip_data = [] for p in ckpt_paths: sd, metadata = comfy.utils.load_torch_file(p, safe_load=True, return_metadata=True) - if metadata is not None: - quant_metadata = metadata.get("_quantization_metadata", None) - if quant_metadata is not None: - sd["_quantization_metadata"] = quant_metadata + if model_options.get("custom_operations", None) is None: + sd, metadata = comfy.utils.convert_old_quants(sd, model_prefix="", metadata=metadata) clip_data.append(sd) return load_text_encoder_state_dicts(clip_data, embedding_directory=embedding_directory, clip_type=clip_type, model_options=model_options) @@ -1088,7 +1107,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=False, clip_g=True, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=False, clip_g=True, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sdxl_clip.SDXLRefinerClipModel @@ -1112,7 +1131,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif clip_type == CLIPType.HIDREAM: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**t5xxl_detect(clip_data), - clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None, llama_scaled_fp8=None) + clip_l=False, clip_g=False, t5=True, llama=False, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) @@ -1141,7 +1160,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None) elif te_model == TEModel.LLAMA3_8: clip_target.clip = comfy.text_encoders.hidream.hidream_clip(**llama_detect(clip_data), - clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None, t5xxl_scaled_fp8=None) + clip_l=False, clip_g=False, t5=False, llama=True, dtype_t5=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer elif te_model == TEModel.QWEN25_3B: clip_target.clip = comfy.text_encoders.omnigen2.te(**llama_detect(clip_data)) @@ -1169,7 +1188,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip clip_target.clip = comfy.text_encoders.sd3_clip.sd3_clip(clip_l=True, clip_g=False, t5=False) clip_target.tokenizer = comfy.text_encoders.sd3_clip.SD3Tokenizer elif clip_type == CLIPType.HIDREAM: - clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None) + clip_target.clip = comfy.text_encoders.hidream.hidream_clip(clip_l=True, clip_g=False, t5=False, llama=False, dtype_t5=None, dtype_llama=None) clip_target.tokenizer = comfy.text_encoders.hidream.HiDreamTokenizer else: clip_target.clip = sd1_clip.SD1ClipModel @@ -1224,19 +1243,10 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip parameters = 0 for c in clip_data: - if "_quantization_metadata" in c: - c.pop("_quantization_metadata") parameters += comfy.utils.calculate_parameters(c) tokenizer_data, model_options = comfy.text_encoders.long_clipl.model_options_long_clip(c, tokenizer_data, model_options) - clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, model_options=model_options) - for c in clip_data: - m, u = clip.load_sd(c) - if len(m) > 0: - logging.warning("clip missing: {}".format(m)) - - if len(u) > 0: - logging.debug("clip unexpected: {}".format(u)) + clip = CLIP(clip_target, embedding_directory=embedding_directory, parameters=parameters, tokenizer_data=tokenizer_data, state_dict=clip_data, model_options=model_options) return clip def load_gligen(ckpt_path): @@ -1295,6 +1305,10 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c weight_dtype = comfy.utils.weight_dtype(sd, diffusion_model_prefix) load_device = model_management.get_torch_device() + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, diffusion_model_prefix, metadata=metadata) + model_config = model_detection.model_config_from_unet(sd, diffusion_model_prefix, metadata=metadata) if model_config is None: logging.warning("Warning, This is not a checkpoint file, trying to load it as a diffusion model only.") @@ -1303,18 +1317,22 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c return None return (diffusion_model, None, VAE(sd={}), None) # The VAE object is there to throw an exception if it's actually used' - unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None - model_config.custom_operations = model_options.get("custom_operations", None) + if custom_operations is not None: + model_config.custom_operations = custom_operations + unet_dtype = model_options.get("dtype", model_options.get("weight_dtype", None)) if unet_dtype is None: unet_dtype = model_management.unet_dtype(model_params=parameters, supported_dtypes=unet_weight_dtype, weight_dtype=weight_dtype) - manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) + if model_config.quant_config is not None: + manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) + else: + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) if model_config.clip_vision_prefix is not None: @@ -1332,22 +1350,33 @@ def load_state_dict_guess_config(sd, output_vae=True, output_clip=True, output_c vae = VAE(sd=vae_sd, metadata=metadata) if output_clip: + if te_model_options.get("custom_operations", None) is None: + scaled_fp8_list = [] + for k in list(sd.keys()): # Convert scaled fp8 to mixed ops + if k.endswith(".scaled_fp8"): + scaled_fp8_list.append(k[:-len("scaled_fp8")]) + + if len(scaled_fp8_list) > 0: + out_sd = {} + for k in sd: + skip = False + for pref in scaled_fp8_list: + skip = skip or k.startswith(pref) + if not skip: + out_sd[k] = sd[k] + + for pref in scaled_fp8_list: + quant_sd, qmetadata = comfy.utils.convert_old_quants(sd, pref, metadata={}) + for k in quant_sd: + out_sd[k] = quant_sd[k] + sd = out_sd + clip_target = model_config.clip_target(state_dict=sd) if clip_target is not None: clip_sd = model_config.process_clip_state_dict(sd) if len(clip_sd) > 0: parameters = comfy.utils.calculate_parameters(clip_sd) - clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, model_options=te_model_options) - m, u = clip.load_sd(clip_sd, full_model=True) - if len(m) > 0: - m_filter = list(filter(lambda a: ".logit_scale" not in a and ".transformer.text_projection.weight" not in a, m)) - if len(m_filter) > 0: - logging.warning("clip missing: {}".format(m)) - else: - logging.debug("clip missing: {}".format(m)) - - if len(u) > 0: - logging.debug("clip unexpected {}:".format(u)) + clip = CLIP(clip_target, embedding_directory=embedding_directory, tokenizer_data=clip_sd, parameters=parameters, state_dict=clip_sd, model_options=te_model_options) else: logging.warning("no CLIP/text encoder weights in checkpoint, the text encoder model will not be loaded.") @@ -1394,6 +1423,9 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): if len(temp_sd) > 0: sd = temp_sd + custom_operations = model_options.get("custom_operations", None) + if custom_operations is None: + sd, metadata = comfy.utils.convert_old_quants(sd, "", metadata=metadata) parameters = comfy.utils.calculate_parameters(sd) weight_dtype = comfy.utils.weight_dtype(sd) @@ -1424,7 +1456,7 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): offload_device = model_management.unet_offload_device() unet_weight_dtype = list(model_config.supported_inference_dtypes) - if model_config.scaled_fp8 is not None: + if model_config.quant_config is not None: weight_dtype = None if dtype is None: @@ -1432,12 +1464,15 @@ def load_diffusion_model_state_dict(sd, model_options={}, metadata=None): else: unet_dtype = dtype - if model_config.layer_quant_config is not None: + if model_config.quant_config is not None: manual_cast_dtype = model_management.unet_manual_cast(None, load_device, model_config.supported_inference_dtypes) else: manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device, model_config.supported_inference_dtypes) model_config.set_inference_dtype(unet_dtype, manual_cast_dtype) - model_config.custom_operations = model_options.get("custom_operations", model_config.custom_operations) + + if custom_operations is not None: + model_config.custom_operations = custom_operations + if model_options.get("fp8_optimizations", False): model_config.optimizations["fp8"] = True @@ -1476,6 +1511,9 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m if vae is not None: vae_sd = vae.get_sd() + if metadata is None: + metadata = {} + model_management.load_models_gpu(load_models, force_patch_weights=True) clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None sd = model.model.state_dict_for_saving(clip_sd, vae_sd, clip_vision_sd) diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index 503a51843..962948dae 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -107,29 +107,17 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder): config[k] = v operations = model_options.get("custom_operations", None) - scaled_fp8 = None - quantization_metadata = model_options.get("quantization_metadata", None) + quant_config = model_options.get("quantization_metadata", None) if operations is None: - layer_quant_config = None - if quantization_metadata is not None: - layer_quant_config = json.loads(quantization_metadata).get("layers", None) - - if layer_quant_config is not None: - operations = comfy.ops.mixed_precision_ops(layer_quant_config, dtype, full_precision_mm=True) - logging.info(f"Using MixedPrecisionOps for text encoder: {len(layer_quant_config)} quantized layers") + if quant_config is not None: + operations = comfy.ops.mixed_precision_ops(quant_config, dtype, full_precision_mm=True) + logging.info("Using MixedPrecisionOps for text encoder") else: - # Fallback to scaled_fp8_ops for backward compatibility - scaled_fp8 = model_options.get("scaled_fp8", None) - if scaled_fp8 is not None: - operations = comfy.ops.scaled_fp8_ops(fp8_matrix_mult=False, override_dtype=scaled_fp8) - else: - operations = comfy.ops.manual_cast + operations = comfy.ops.manual_cast self.operations = operations self.transformer = model_class(config, dtype, device, self.operations) - if scaled_fp8 is not None: - self.transformer.scaled_fp8 = torch.nn.Parameter(torch.tensor([], dtype=scaled_fp8)) self.num_layers = self.transformer.num_layers diff --git a/comfy/supported_models_base.py b/comfy/supported_models_base.py index e4bd74514..9fd84d329 100644 --- a/comfy/supported_models_base.py +++ b/comfy/supported_models_base.py @@ -49,8 +49,7 @@ class BASE: manual_cast_dtype = None custom_operations = None - scaled_fp8 = None - layer_quant_config = None # Per-layer quantization configuration for mixed precision + quant_config = None # quantization configuration for mixed precision optimizations = {"fp8": False} @classmethod diff --git a/comfy/text_encoders/cosmos.py b/comfy/text_encoders/cosmos.py index a1adb5242..448381fa9 100644 --- a/comfy/text_encoders/cosmos.py +++ b/comfy/text_encoders/cosmos.py @@ -7,10 +7,10 @@ from transformers import T5TokenizerFast class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=True, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_old_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, zero_out_masked=attention_mask, model_options=model_options) @@ -30,12 +30,12 @@ class CosmosT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class CosmosTEModel_(CosmosT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/flux.py b/comfy/text_encoders/flux.py index 99f4812bb..21d93d757 100644 --- a/comfy/text_encoders/flux.py +++ b/comfy/text_encoders/flux.py @@ -63,12 +63,12 @@ class FluxClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def flux_clip(dtype_t5=None, t5xxl_scaled_fp8=None): +def flux_clip(dtype_t5=None, t5_quantization_metadata=None): class FluxClipModel_(FluxClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(dtype_t5=dtype_t5, device=device, dtype=dtype, model_options=model_options) return FluxClipModel_ @@ -159,15 +159,13 @@ class Flux2TEModel(sd1_clip.SD1ClipModel): out = out.reshape(out.shape[0], out.shape[1], -1) return out, pooled, extra -def flux2_te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None, pruned=False): +def flux2_te(dtype_llama=None, llama_quantization_metadata=None, pruned=False): class Flux2TEModel_(Flux2TEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: + model_options = model_options.copy() model_options["quantization_metadata"] = llama_quantization_metadata if pruned: model_options = model_options.copy() diff --git a/comfy/text_encoders/genmo.py b/comfy/text_encoders/genmo.py index 9dcf190a2..5daea8135 100644 --- a/comfy/text_encoders/genmo.py +++ b/comfy/text_encoders/genmo.py @@ -26,12 +26,12 @@ class MochiT5Tokenizer(sd1_clip.SD1Tokenizer): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def mochi_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def mochi_te(dtype_t5=None, t5_quantization_metadata=None): class MochiTEModel_(MochiT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hidream.py b/comfy/text_encoders/hidream.py index dbcf52784..600b34480 100644 --- a/comfy/text_encoders/hidream.py +++ b/comfy/text_encoders/hidream.py @@ -142,14 +142,14 @@ class HiDreamTEModel(torch.nn.Module): return self.llama.load_sd(sd) -def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None): +def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5_quantization_metadata=None, llama_quantization_metadata=None): class HiDreamTEModel_(HiDreamTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HiDreamTEModel_ diff --git a/comfy/text_encoders/hunyuan_image.py b/comfy/text_encoders/hunyuan_image.py index ff04726e1..cd198036c 100644 --- a/comfy/text_encoders/hunyuan_image.py +++ b/comfy/text_encoders/hunyuan_image.py @@ -40,10 +40,10 @@ class HunyuanImageTokenizer(QwenImageTokenizer): class Qwen25_7BVLIModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}): - llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=comfy.text_encoders.llama.Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -91,12 +91,12 @@ class HunyuanImageTEModel(QwenImageTEModel): else: return super().load_sd(sd) -def te(byt5=True, dtype_llama=None, llama_scaled_fp8=None): +def te(byt5=True, dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(HunyuanImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["qwen_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(byt5=byt5, device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/hunyuan_video.py b/comfy/text_encoders/hunyuan_video.py index 0110517bb..a9a6c525e 100644 --- a/comfy/text_encoders/hunyuan_video.py +++ b/comfy/text_encoders/hunyuan_video.py @@ -6,7 +6,7 @@ from transformers import LlamaTokenizerFast import torch import os import numbers - +import comfy.utils def llama_detect(state_dict, prefix=""): out = {} @@ -14,12 +14,9 @@ def llama_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_llama"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["llama_scaled_fp8"] = state_dict[scaled_fp8_key].dtype - - if "_quantization_metadata" in state_dict: - out["llama_quantization_metadata"] = state_dict["_quantization_metadata"] + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["llama_quantization_metadata"] = quant return out @@ -31,10 +28,10 @@ class LLAMA3Tokenizer(sd1_clip.SDTokenizer): class LLAMAModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="hidden", layer_idx=-3, dtype=None, attention_mask=True, model_options={}, special_tokens={"start": 128000, "pad": 128258}): - llama_scaled_fp8 = model_options.get("llama_scaled_fp8", None) - if llama_scaled_fp8 is not None: + llama_quantization_metadata = model_options.get("llama_quantization_metadata", None) + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata textmodel_json_config = {} vocab_size = model_options.get("vocab_size", None) @@ -161,11 +158,11 @@ class HunyuanVideoClipModel(torch.nn.Module): return self.llama.load_sd(sd) -def hunyuan_video_clip(dtype_llama=None, llama_scaled_fp8=None): +def hunyuan_video_clip(dtype_llama=None, llama_quantization_metadata=None): class HunyuanVideoClipModel_(HunyuanVideoClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "llama_scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["llama_scaled_fp8"] = llama_scaled_fp8 + model_options["llama_quantization_metadata"] = llama_quantization_metadata super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options) return HunyuanVideoClipModel_ diff --git a/comfy/text_encoders/lumina2.py b/comfy/text_encoders/lumina2.py index fd986e2c1..7a6cfdab2 100644 --- a/comfy/text_encoders/lumina2.py +++ b/comfy/text_encoders/lumina2.py @@ -40,7 +40,7 @@ class LuminaModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name=name, clip_model=clip_model, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): +def te(dtype_llama=None, llama_quantization_metadata=None, model_type="gemma2_2b"): if model_type == "gemma2_2b": model = Gemma2_2BModel elif model_type == "gemma3_4b": @@ -48,9 +48,9 @@ def te(dtype_llama=None, llama_scaled_fp8=None, model_type="gemma2_2b"): class LuminaTEModel_(LuminaModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, name=model_type, model_options=model_options, clip_model=model) diff --git a/comfy/text_encoders/omnigen2.py b/comfy/text_encoders/omnigen2.py index 1a01b2dd4..50aa4121f 100644 --- a/comfy/text_encoders/omnigen2.py +++ b/comfy/text_encoders/omnigen2.py @@ -32,12 +32,12 @@ class Omnigen2Model(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen25_3b", clip_model=Qwen25_3BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class Omnigen2TEModel_(Omnigen2Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/ovis.py b/comfy/text_encoders/ovis.py index 81c9bd51c..5754424d2 100644 --- a/comfy/text_encoders/ovis.py +++ b/comfy/text_encoders/ovis.py @@ -55,12 +55,9 @@ class OvisTEModel(sd1_clip.SD1ClipModel): return out, pooled, {} -def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class OvisTEModel_(OvisTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py index 5f383de07..e5e5f18be 100644 --- a/comfy/text_encoders/pixart_t5.py +++ b/comfy/text_encoders/pixart_t5.py @@ -30,12 +30,12 @@ class PixArtTokenizer(sd1_clip.SD1Tokenizer): def __init__(self, embedding_directory=None, tokenizer_data={}): super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) -def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None): +def pixart_te(dtype_t5=None, t5_quantization_metadata=None): class PixArtTEModel_(PixArtT5XXL): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata if dtype is None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/qwen_image.py b/comfy/text_encoders/qwen_image.py index c0d32a6ef..5c14dec23 100644 --- a/comfy/text_encoders/qwen_image.py +++ b/comfy/text_encoders/qwen_image.py @@ -85,12 +85,12 @@ class QwenImageTEModel(sd1_clip.SD1ClipModel): return out, pooled, extra -def te(dtype_llama=None, llama_scaled_fp8=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class QwenImageTEModel_(QwenImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if llama_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 + model_options["quantization_metadata"] = llama_quantization_metadata if dtype_llama is not None: dtype = dtype_llama super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/sd3_clip.py b/comfy/text_encoders/sd3_clip.py index ff5d412db..8b153c72b 100644 --- a/comfy/text_encoders/sd3_clip.py +++ b/comfy/text_encoders/sd3_clip.py @@ -6,14 +6,15 @@ import torch import os import comfy.model_management import logging +import comfy.utils class T5XXLModel(sd1_clip.SDClipModel): def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, attention_mask=False, model_options={}): textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "t5_config_xxl.json") - t5xxl_scaled_fp8 = model_options.get("t5xxl_scaled_fp8", None) - if t5xxl_scaled_fp8 is not None: + t5xxl_quantization_metadata = model_options.get("t5xxl_quantization_metadata", None) + if t5xxl_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5xxl_quantization_metadata model_options = {**model_options, "model_name": "t5xxl"} super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options) @@ -25,9 +26,9 @@ def t5_xxl_detect(state_dict, prefix=""): if t5_key in state_dict: out["dtype_t5"] = state_dict[t5_key].dtype - scaled_fp8_key = "{}scaled_fp8".format(prefix) - if scaled_fp8_key in state_dict: - out["t5xxl_scaled_fp8"] = state_dict[scaled_fp8_key].dtype + quant = comfy.utils.detect_layer_quantization(state_dict, prefix) + if quant is not None: + out["t5_quantization_metadata"] = quant return out @@ -156,11 +157,11 @@ class SD3ClipModel(torch.nn.Module): else: return self.t5xxl.load_sd(sd) -def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5xxl_scaled_fp8=None, t5_attention_mask=False): +def sd3_clip(clip_l=True, clip_g=True, t5=True, dtype_t5=None, t5_quantization_metadata=None, t5_attention_mask=False): class SD3ClipModel_(SD3ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, dtype_t5=dtype_t5, t5_attention_mask=t5_attention_mask, device=device, dtype=dtype, model_options=model_options) return SD3ClipModel_ diff --git a/comfy/text_encoders/wan.py b/comfy/text_encoders/wan.py index d50fa4b28..164a57edd 100644 --- a/comfy/text_encoders/wan.py +++ b/comfy/text_encoders/wan.py @@ -25,12 +25,12 @@ class WanT5Model(sd1_clip.SD1ClipModel): def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs): super().__init__(device=device, dtype=dtype, model_options=model_options, name="umt5xxl", clip_model=UMT5XXlModel, **kwargs) -def te(dtype_t5=None, t5xxl_scaled_fp8=None): +def te(dtype_t5=None, t5_quantization_metadata=None): class WanTEModel(WanT5Model): def __init__(self, device="cpu", dtype=None, model_options={}): - if t5xxl_scaled_fp8 is not None and "scaled_fp8" not in model_options: + if t5_quantization_metadata is not None: model_options = model_options.copy() - model_options["scaled_fp8"] = t5xxl_scaled_fp8 + model_options["quantization_metadata"] = t5_quantization_metadata if dtype_t5 is not None: dtype = dtype_t5 super().__init__(device=device, dtype=dtype, model_options=model_options) diff --git a/comfy/text_encoders/z_image.py b/comfy/text_encoders/z_image.py index bb9273b20..19adde0b7 100644 --- a/comfy/text_encoders/z_image.py +++ b/comfy/text_encoders/z_image.py @@ -34,12 +34,9 @@ class ZImageTEModel(sd1_clip.SD1ClipModel): super().__init__(device=device, dtype=dtype, name="qwen3_4b", clip_model=Qwen3_4BModel, model_options=model_options) -def te(dtype_llama=None, llama_scaled_fp8=None, llama_quantization_metadata=None): +def te(dtype_llama=None, llama_quantization_metadata=None): class ZImageTEModel_(ZImageTEModel): def __init__(self, device="cpu", dtype=None, model_options={}): - if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options: - model_options = model_options.copy() - model_options["scaled_fp8"] = llama_scaled_fp8 if dtype_llama is not None: dtype = dtype_llama if llama_quantization_metadata is not None: diff --git a/comfy/utils.py b/comfy/utils.py index 37485e497..89846bc95 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -29,6 +29,7 @@ import itertools from torch.nn.functional import interpolate from einops import rearrange from comfy.cli_args import args +import json MMAP_TORCH_FILES = args.mmap_torch_files DISABLE_MMAP = args.disable_mmap @@ -1194,3 +1195,68 @@ def unpack_latents(combined_latent, latent_shapes): else: output_tensors = combined_latent return output_tensors + +def detect_layer_quantization(state_dict, prefix): + for k in state_dict: + if k.startswith(prefix) and k.endswith(".comfy_quant"): + logging.info("Found quantization metadata version 1") + return {"mixed_ops": True} + return None + +def convert_old_quants(state_dict, model_prefix="", metadata={}): + if metadata is None: + metadata = {} + + quant_metadata = None + if "_quantization_metadata" not in metadata: + scaled_fp8_key = "{}scaled_fp8".format(model_prefix) + + if scaled_fp8_key in state_dict: + scaled_fp8_weight = state_dict[scaled_fp8_key] + scaled_fp8_dtype = scaled_fp8_weight.dtype + if scaled_fp8_dtype == torch.float32: + scaled_fp8_dtype = torch.float8_e4m3fn + + if scaled_fp8_weight.nelement() == 2: + full_precision_matrix_mult = True + else: + full_precision_matrix_mult = False + + out_sd = {} + layers = {} + for k in list(state_dict.keys()): + if not k.startswith(model_prefix): + out_sd[k] = state_dict[k] + continue + k_out = k + w = state_dict.pop(k) + layer = None + if k_out.endswith(".scale_weight"): + layer = k_out[:-len(".scale_weight")] + k_out = "{}.weight_scale".format(layer) + + if layer is not None: + layer_conf = {"format": "float8_e4m3fn"} # TODO: check if anyone did some non e4m3fn scaled checkpoints + if full_precision_matrix_mult: + layer_conf["full_precision_matrix_mult"] = full_precision_matrix_mult + layers[layer] = layer_conf + + if k_out.endswith(".scale_input"): + layer = k_out[:-len(".scale_input")] + k_out = "{}.input_scale".format(layer) + if w.item() == 1.0: + continue + + out_sd[k_out] = w + + state_dict = out_sd + quant_metadata = {"layers": layers} + else: + quant_metadata = json.loads(metadata["_quantization_metadata"]) + + if quant_metadata is not None: + layers = quant_metadata["layers"] + for k, v in layers.items(): + state_dict["{}.comfy_quant".format(k)] = torch.frombuffer(json.dumps(v).encode('utf-8'), dtype=torch.uint8) + + return state_dict, metadata diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 1c3d9e697..3799a9004 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -26,6 +26,9 @@ class ContextWindowsManualNode(io.ComfyNode): io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), @@ -34,7 +37,8 @@ class ContextWindowsManualNode(io.ComfyNode): ) @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, dim: int, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: model = model.clone() model.model_options["context_handler"] = comfy.context_windows.IndexListContextHandler( context_schedule=comfy.context_windows.get_matching_context_schedule(context_schedule), @@ -43,9 +47,15 @@ class ContextWindowsManualNode(io.ComfyNode): context_overlap=context_overlap, context_stride=context_stride, closed_loop=closed_loop, - dim=dim) + dim=dim, + freenoise=freenoise, + cond_retain_index_list=cond_retain_index_list, + split_conds_to_windows=split_conds_to_windows + ) # make memory usage calculation only take into account the context window latents comfy.context_windows.create_prepare_sampling_wrapper(model) + if freenoise: # no other use for this wrapper at this time + comfy.context_windows.create_sampler_sample_wrapper(model) return io.NodeOutput(model) class WanContextWindowsManualNode(ContextWindowsManualNode): @@ -68,14 +78,18 @@ class WanContextWindowsManualNode(ContextWindowsManualNode): io.Int.Input("context_stride", min=1, default=1, tooltip="The stride of the context window; only applicable to uniform schedules."), io.Boolean.Input("closed_loop", default=False, tooltip="Whether to close the context window loop; only applicable to looped schedules."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), + io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), + #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ] return schema @classmethod - def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str) -> io.Model: + def execute(cls, model: io.Model.Type, context_length: int, context_overlap: int, context_schedule: str, context_stride: int, closed_loop: bool, fuse_method: str, freenoise: bool, + cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False) -> io.Model: context_length = max(((context_length - 1) // 4) + 1, 1) # at least length 1 context_overlap = max(((context_overlap - 1) // 4) + 1, 0) # at least overlap 0 - return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2) + return super().execute(model, context_length, context_overlap, context_schedule, context_stride, closed_loop, fuse_method, dim=2, freenoise=freenoise, cond_retain_index_list=cond_retain_index_list, split_conds_to_windows=split_conds_to_windows) class ContextWindowsExtension(ComfyExtension): diff --git a/tests-unit/comfy_quant/test_mixed_precision.py b/tests-unit/comfy_quant/test_mixed_precision.py index 63361309f..3a54941e6 100644 --- a/tests-unit/comfy_quant/test_mixed_precision.py +++ b/tests-unit/comfy_quant/test_mixed_precision.py @@ -2,6 +2,7 @@ import unittest import torch import sys import os +import json # Add comfy to path sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) @@ -15,6 +16,7 @@ if not has_gpu(): from comfy import ops from comfy.quant_ops import QuantizedTensor +import comfy.utils class SimpleModel(torch.nn.Module): @@ -94,8 +96,9 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32), } + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) # Create model and load state dict (strict=False because custom loading pops keys) - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict, strict=False) # Verify weights are wrapped in QuantizedTensor @@ -115,7 +118,8 @@ class TestMixedPrecisionOps(unittest.TestCase): # Forward pass input_tensor = torch.randn(5, 10, dtype=torch.bfloat16) - output = model(input_tensor) + with torch.inference_mode(): + output = model(input_tensor) self.assertEqual(output.shape, (5, 40)) @@ -141,7 +145,8 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict1, strict=False) # Save state dict @@ -178,7 +183,8 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + model = SimpleModel(operations=ops.mixed_precision_ops({})) model.load_state_dict(state_dict, strict=False) # Add a weight function (simulating LoRA) @@ -215,8 +221,10 @@ class TestMixedPrecisionOps(unittest.TestCase): "layer3.bias": torch.randn(40, dtype=torch.bfloat16), } + state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})}) + # Load should raise KeyError for unknown format in QUANT_FORMAT_MIXINS - model = SimpleModel(operations=ops.mixed_precision_ops(layer_quant_config)) + model = SimpleModel(operations=ops.mixed_precision_ops({})) with self.assertRaises(KeyError): model.load_state_dict(state_dict, strict=False)