Merge upstream/master, keep local README.md

This commit is contained in:
GitHub Actions 2026-01-11 00:43:00 +00:00
commit 6e634fa5fe
9 changed files with 29 additions and 19 deletions

View File

@ -237,6 +237,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
else:
dit_config["vec_in_dim"] = None
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma

View File

@ -368,7 +368,7 @@ try:
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1201"]):
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
ENABLE_PYTORCH_ATTENTION = True
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0

View File

@ -625,21 +625,29 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
missing_keys.remove(key)
def state_dict(self, *args, destination=None, prefix="", **kwargs):
sd = super().state_dict(*args, destination=destination, prefix=prefix, **kwargs)
if isinstance(self.weight, QuantizedTensor):
layout_cls = self.weight._layout_cls
if destination is not None:
sd = destination
else:
sd = {}
# Check if it's any FP8 variant (E4M3 or E5M2)
if layout_cls in ("TensorCoreFP8E4M3Layout", "TensorCoreFP8E5M2Layout", "TensorCoreFP8Layout"):
sd["{}weight_scale".format(prefix)] = self.weight._params.scale
elif layout_cls == "TensorCoreNVFP4Layout":
sd["{}weight_scale_2".format(prefix)] = self.weight._params.scale
sd["{}weight_scale".format(prefix)] = self.weight._params.block_scale
if self.bias is not None:
sd["{}bias".format(prefix)] = self.bias
if isinstance(self.weight, QuantizedTensor):
sd_out = self.weight.state_dict("{}weight".format(prefix))
for k in sd_out:
sd[k] = sd_out[k]
quant_conf = {"format": self.quant_format}
if self._full_precision_mm_config:
quant_conf["full_precision_matrix_mult"] = True
sd["{}comfy_quant".format(prefix)] = torch.tensor(list(json.dumps(quant_conf).encode('utf-8')), dtype=torch.uint8)
input_scale = getattr(self, 'input_scale', None)
if input_scale is not None:
sd["{}input_scale".format(prefix)] = input_scale
else:
sd["{}weight".format(prefix)] = self.weight
return sd
def _forward(self, input, weight, bias):

View File

@ -1059,9 +1059,9 @@ def detect_te_model(sd):
return TEModel.JINA_CLIP_2
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
if weight.shape[-1] == 4096:
if weight.shape[0] == 10240:
return TEModel.T5_XXL
elif weight.shape[-1] == 2048:
elif weight.shape[0] == 5120:
return TEModel.T5_XL
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
return TEModel.T5_XXL_OLD

View File

@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return CosmosTEModel_

View File

@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return MochiTEModel_

View File

@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
if t5_quantization_metadata is not None:
model_options = model_options.copy()
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
if dtype is None:
if dtype_t5 is not None:
dtype = dtype_t5
super().__init__(device=device, dtype=dtype, model_options=model_options)
return PixArtTEModel_

View File

@ -1,6 +1,6 @@
comfyui-frontend-package==1.36.13
comfyui-workflow-templates==0.7.69
comfyui-embedded-docs==0.3.1
comfyui-embedded-docs==0.4.0
torch
torchsde
torchvision

View File

@ -153,9 +153,9 @@ class TestMixedPrecisionOps(unittest.TestCase):
state_dict2 = model.state_dict()
# Verify layer1.weight is a QuantizedTensor with scale preserved
self.assertIsInstance(state_dict2["layer1.weight"], QuantizedTensor)
self.assertEqual(state_dict2["layer1.weight"]._params.scale.item(), 3.0)
self.assertEqual(state_dict2["layer1.weight"]._layout_cls, "TensorCoreFP8E4M3Layout")
self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")
# Verify non-quantized layers are standard tensors
self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)