diff --git a/comfy/ops.py b/comfy/ops.py index b5cd1d47e..b33fde1aa 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -1159,6 +1159,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec self._buffers[key] = fn(buf) return self + class Embedding(manual_cast.Embedding): + def _load_from_state_dict(self, state_dict, prefix, local_metadata, + strict, missing_keys, unexpected_keys, error_msgs): + weight_key = f"{prefix}weight" + layer_conf = state_dict.pop(f"{prefix}comfy_quant", None) + if layer_conf is not None: + layer_conf = json.loads(layer_conf.numpy().tobytes()) + + # Only fp8 makes sense for embeddings (per-row dequant via index select). + # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently. + quant_format = layer_conf.get("format", None) if layer_conf is not None else None + if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict: + self.quant_format = quant_format + qconfig = QUANT_ALGOS[quant_format] + layout_cls = get_layout_class(qconfig["comfy_tensor_layout"]) + weight = state_dict.pop(weight_key) + manually_loaded_keys = [weight_key] + + scale_key = f"{prefix}weight_scale" + scale = state_dict.pop(scale_key, None) + if scale is not None: + scale = scale.float() + manually_loaded_keys.append(scale_key) + + params = layout_cls.Params( + scale=scale if scale is not None else torch.ones((), dtype=torch.float32), + orig_dtype=MixedPrecisionOps._compute_dtype, + orig_shape=(self.num_embeddings, self.embedding_dim), + ) + self.weight = torch.nn.Parameter( + QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params), + requires_grad=False) + + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + for k in manually_loaded_keys: + if k in missing_keys: + missing_keys.remove(k) + else: + if layer_conf is not None: + state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8) + super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) + + def state_dict(self, *args, destination=None, prefix="", **kwargs): + if destination is not None: + sd = destination + else: + sd = {} + + if not hasattr(self, 'weight') or self.weight is None: + return sd + + if isinstance(self.weight, QuantizedTensor): + sd_out = self.weight.state_dict("{}weight".format(prefix)) + for k in sd_out: + sd[k] = sd_out[k] + + quant_conf = {"format": self.quant_format} + sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8) + else: + sd["{}weight".format(prefix)] = self.weight + return sd + + def forward_comfy_cast_weights(self, input, out_dtype=None): + weight = self.weight + + # Optimized path: lookup in fp8, dequantize only the selected rows. + if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0: + qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True) + if isinstance(qdata, QuantizedTensor): + scale = qdata._params.scale + qdata = qdata._qdata + else: + scale = None + + x = torch.nn.functional.embedding( + input, qdata, self.padding_idx, self.max_norm, + self.norm_type, self.scale_grad_by_freq, self.sparse) + uncast_bias_weight(self, qdata, None, offload_stream) + target_dtype = out_dtype if out_dtype is not None else weight.params.orig_dtype + x = x.to(dtype=target_dtype) + if scale is not None and scale != 1.0: + x = x * scale.to(dtype=target_dtype) + return x + + # Fallback for non-quantized or weight_function (LoRA) case + return super().forward_comfy_cast_weights(input, out_dtype=out_dtype) + return MixedPrecisionOps def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):