mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-27 23:00:20 +08:00
Compare commits
No commits in common. "2f642d5d9b48ad7cad13bbdd5f8adcf506f565a7" and "dc202a2e51bf7a6cd00e606b2d2941bc223f2ad2" have entirely different histories.
2f642d5d9b
...
dc202a2e51
@ -237,8 +237,6 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
else:
|
||||
dit_config["vec_in_dim"] = None
|
||||
|
||||
dit_config["num_heads"] = dit_config["hidden_size"] // sum(dit_config["axes_dim"])
|
||||
|
||||
dit_config["depth"] = count_blocks(state_dict_keys, '{}double_blocks.'.format(key_prefix) + '{}.')
|
||||
dit_config["depth_single_blocks"] = count_blocks(state_dict_keys, '{}single_blocks.'.format(key_prefix) + '{}.')
|
||||
if '{}distilled_guidance_layer.0.norms.0.scale'.format(key_prefix) in state_dict_keys or '{}distilled_guidance_layer.norms.0.scale'.format(key_prefix) in state_dict_keys: #Chroma
|
||||
|
||||
@ -368,7 +368,7 @@ try:
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
if any((a in arch) for a in ["gfx1201"]):
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if torch_version_numeric >= (2, 7) and rocm_version >= (6, 4):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201", "gfx950"]): # TODO: more arches, "gfx942" gives error on pytorch nightly 2.10 1013 rocm7.0
|
||||
|
||||
@ -1059,9 +1059,9 @@ def detect_te_model(sd):
|
||||
return TEModel.JINA_CLIP_2
|
||||
if "encoder.block.23.layer.1.DenseReluDense.wi_1.weight" in sd:
|
||||
weight = sd["encoder.block.23.layer.1.DenseReluDense.wi_1.weight"]
|
||||
if weight.shape[0] == 10240:
|
||||
if weight.shape[-1] == 4096:
|
||||
return TEModel.T5_XXL
|
||||
elif weight.shape[0] == 5120:
|
||||
elif weight.shape[-1] == 2048:
|
||||
return TEModel.T5_XL
|
||||
if 'encoder.block.23.layer.1.DenseReluDense.wi.weight' in sd:
|
||||
return TEModel.T5_XXL_OLD
|
||||
|
||||
@ -36,7 +36,7 @@ def te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype_t5 is not None:
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return CosmosTEModel_
|
||||
|
||||
@ -32,7 +32,7 @@ def mochi_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype_t5 is not None:
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return MochiTEModel_
|
||||
|
||||
@ -36,7 +36,7 @@ def pixart_te(dtype_t5=None, t5_quantization_metadata=None):
|
||||
if t5_quantization_metadata is not None:
|
||||
model_options = model_options.copy()
|
||||
model_options["t5xxl_quantization_metadata"] = t5_quantization_metadata
|
||||
if dtype_t5 is not None:
|
||||
if dtype is None:
|
||||
dtype = dtype_t5
|
||||
super().__init__(device=device, dtype=dtype, model_options=model_options)
|
||||
return PixArtTEModel_
|
||||
|
||||
Loading…
Reference in New Issue
Block a user