Mirror of https://github.com/comfyanonymous/ComfyUI.git
Add fp8 scaled embedding support
parent 0fc398a821
commit ba3a484c06
comfy/ops.py (+87 lines)
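
The diff below adds an fp8-aware Embedding to the mixed-precision op set. As orientation, a minimal standalone sketch of the underlying scheme, assuming a PyTorch build with float8 support; the shapes and variable names here are illustrative, not ComfyUI APIs:

import torch

# Quantize: a single tensor-wide float32 scale maps the largest entry near
# the fp8 maximum (448 for float8_e4m3fn).
vocab, dim = 1024, 64
table_hp = torch.randn(vocab, dim)
scale = table_hp.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
table_fp8 = (table_hp / scale).to(torch.float8_e4m3fn)

# Lookup: the gather runs on the raw fp8 storage; only the selected rows
# are cast up and rescaled afterwards.
ids = torch.tensor([3, 17, 42])
rows = torch.nn.functional.embedding(ids, table_fp8)
out = rows.to(torch.bfloat16) * scale.to(torch.bfloat16)

This is the same ordering the new forward_comfy_cast_weights uses: index select first, dequantize second, so only the looked-up rows ever exist in the compute dtype.
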
@@ -1159,6 +1159,93 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
                        self._buffers[key] = fn(buf)
                return self

        class Embedding(manual_cast.Embedding):
            def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                                      strict, missing_keys, unexpected_keys, error_msgs):
                weight_key = f"{prefix}weight"
                # Per-layer quantization metadata ships as JSON bytes in a uint8 tensor.
                layer_conf = state_dict.pop(f"{prefix}comfy_quant", None)
                if layer_conf is not None:
                    layer_conf = json.loads(layer_conf.numpy().tobytes())

                # Only fp8 makes sense for embeddings (per-row dequant via index select).
                # Block-scaled formats (NVFP4, MXFP8) can't do per-row lookup efficiently.
                quant_format = layer_conf.get("format", None) if layer_conf is not None else None
                if quant_format in ["float8_e4m3fn", "float8_e5m2"] and weight_key in state_dict:
                    self.quant_format = quant_format
                    qconfig = QUANT_ALGOS[quant_format]
                    layout_cls = get_layout_class(qconfig["comfy_tensor_layout"])
                    weight = state_dict.pop(weight_key)
                    manually_loaded_keys = [weight_key]

                    scale_key = f"{prefix}weight_scale"
                    scale = state_dict.pop(scale_key, None)
                    if scale is not None:
                        scale = scale.float()
                        manually_loaded_keys.append(scale_key)

                    params = layout_cls.Params(
                        scale=scale if scale is not None else torch.ones((), dtype=torch.float32),
                        orig_dtype=MixedPrecisionOps._compute_dtype,
                        orig_shape=(self.num_embeddings, self.embedding_dim),
                    )
                    self.weight = torch.nn.Parameter(
                        QuantizedTensor(weight.to(dtype=qconfig["storage_t"]), qconfig["comfy_tensor_layout"], params),
                        requires_grad=False)

                    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
                    # The weight and scale were consumed above, so they are not
                    # really missing; drop them from the report.
                    for k in manually_loaded_keys:
                        if k in missing_keys:
                            missing_keys.remove(k)
                else:
                    # Unhandled format: put the metadata back and load normally.
                    if layer_conf is not None:
                        state_dict[f"{prefix}comfy_quant"] = torch.tensor(list(json.dumps(layer_conf).encode('utf-8')), dtype=torch.uint8)
                    super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)

            def state_dict(self, *args, destination=None, prefix="", **kwargs):
                if destination is not None:
                    sd = destination
                else:
                    sd = {}

                if not hasattr(self, 'weight') or self.weight is None:
                    return sd

                if isinstance(self.weight, QuantizedTensor):
                    # Re-emit the quantized payload plus its comfy_quant metadata
                    # so a saved checkpoint round-trips through _load_from_state_dict.
                    sd_out = self.weight.state_dict("{}weight".format(prefix))
                    for k in sd_out:
                        sd[k] = sd_out[k]

                    quant_conf = {"format": self.quant_format}
                    sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
                else:
                    sd["{}weight".format(prefix)] = self.weight
                return sd

            def forward_comfy_cast_weights(self, input, out_dtype=None):
                weight = self.weight

                # Optimized path: lookup in fp8, dequantize only the selected rows.
                if isinstance(weight, QuantizedTensor) and len(self.weight_function) == 0:
                    qdata, _, offload_stream = cast_bias_weight(self, device=input.device, dtype=weight.dtype, offloadable=True)
                    if isinstance(qdata, QuantizedTensor):
                        scale = qdata._params.scale
                        qdata = qdata._qdata
                    else:
                        scale = None

                    # The gather runs on the raw fp8 storage; only the selected
                    # rows are cast up and rescaled below.
                    x = torch.nn.functional.embedding(
                        input, qdata, self.padding_idx, self.max_norm,
                        self.norm_type, self.scale_grad_by_freq, self.sparse)
                    uncast_bias_weight(self, qdata, None, offload_stream)
                    target_dtype = out_dtype if out_dtype is not None else weight._params.orig_dtype
                    x = x.to(dtype=target_dtype)
                    if scale is not None and scale != 1.0:
                        x = x * scale.to(dtype=target_dtype)
                    return x

                # Fallback for non-quantized or weight_function (LoRA) case
                return super().forward_comfy_cast_weights(input, out_dtype=out_dtype)

    return MixedPrecisionOps

def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
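
A note on the comfy_quant entries the new code reads and writes: the per-layer config is a JSON dict serialized to UTF-8 bytes and stored as a uint8 tensor, so it can travel inside an ordinary state dict. A round trip of that encoding, lifted straight from the diff:

import json
import torch

conf = {"format": "float8_e4m3fn"}
# Encode: JSON -> UTF-8 bytes -> uint8 tensor (what state_dict() emits).
blob = torch.tensor(list(json.dumps(conf).encode('utf-8')), dtype=torch.uint8)
# Decode: uint8 tensor -> bytes -> JSON (what _load_from_state_dict does).
decoded = json.loads(blob.numpy().tobytes())
assert decoded == conf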
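
Putting the pieces together, a checkpoint holding one quantized embedding carries three entries per layer. The key suffixes (weight, weight_scale, comfy_quant) come from the diff; the "token_embedding." prefix and the shapes below are invented for illustration:

import json
import torch

sd = {
    # fp8 payload, rebuilt into a QuantizedTensor on load
    "token_embedding.weight": torch.zeros(49408, 768, dtype=torch.float8_e4m3fn),
    # tensor-wide dequantization scale; defaults to 1.0 when absent
    "token_embedding.weight_scale": torch.tensor(0.02),
    # uint8 JSON metadata selecting the format
    "token_embedding.comfy_quant": torch.tensor(list(json.dumps({"format": "float8_e4m3fn"}).encode('utf-8')), dtype=torch.uint8),
}

_load_from_state_dict pops all three keys, wraps the payload in a QuantizedTensor, and strips weight and weight_scale from missing_keys so strict loading still passes.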
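
Finally, why the forward path gathers first and dequantizes second: upcasting the whole table materializes a vocab x dim copy in the compute dtype, while this ordering upcasts only the rows actually selected. A back-of-the-envelope comparison with illustrative, CLIP-like sizes:

# Transient bf16 bytes: whole-table upcast vs. per-row upcast.
vocab, dim, tokens = 49408, 768, 77
full_upcast = vocab * dim * 2   # dequantize everything, then gather
row_upcast = tokens * dim * 2   # gather in fp8, then dequantize (this commit)
print(full_upcast // row_upcast)  # -> 641, i.e. ~641x less transient memory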