Fix more nodes

This commit is contained in:
doctorpangloss 2024-01-03 14:29:16 -08:00
parent 58f8c7486d
commit bf42e687fd
2 changed files with 5 additions and 17 deletions

View File

@ -19,7 +19,7 @@ class UpscaleModelLoader:
model_path = folder_paths.get_full_path("upscale_models", model_name) model_path = folder_paths.get_full_path("upscale_models", model_name)
sd = utils.load_torch_file(model_path, safe_load=True) sd = utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = comfy.utils.state_dict_prefix_replace(sd, {"module.":""}) sd = utils.state_dict_prefix_replace(sd, {"module.":""})
out = model_loading.load_state_dict(sd).eval() out = model_loading.load_state_dict(sd).eval()
return (out, ) return (out, )

View File

@ -28,23 +28,16 @@ version = '0.0.1'
""" """
The package index to the torch built with AMD ROCm. The package index to the torch built with AMD ROCm.
""" """
amd_torch_index = "https://download.pytorch.org/whl/rocm5.4.2" amd_torch_index = "https://download.pytorch.org/whl/rocm5.6"
""" """
The package index to torch built with CUDA. The package index to torch built with CUDA.
Observe the CUDA version is in this URL. Observe the CUDA version is in this URL.
""" """
nvidia_torch_index = "https://download.pytorch.org/whl/cu118" nvidia_torch_index = "https://download.pytorch.org/whl/cu121"
""" """
The package index to torch built against CPU features. The package index to torch built against CPU features.
This includes macOS MPS support.
"""
cpu_torch_index_nightlies = "https://download.pytorch.org/whl/nightly/cpu"
"""
The package index to torch built against CPU features.
Non-nightlies are selected when building Linux on arm64.
""" """
cpu_torch_index = "https://download.pytorch.org/whl/cpu" cpu_torch_index = "https://download.pytorch.org/whl/cpu"
@ -110,22 +103,17 @@ def _is_linux_arm64():
def dependencies() -> List[str]: def dependencies() -> List[str]:
_dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines() _dependencies = open(os.path.join(os.path.dirname(__file__), "requirements.txt")).readlines()
# todo: also add all plugin dependencies # todo: also add all plugin dependencies
_alternative_indices = [amd_torch_index, nvidia_torch_index, cpu_torch_index_nightlies] _alternative_indices = [amd_torch_index, nvidia_torch_index]
session = PipSession() session = PipSession()
gpu_accelerated = False
index_urls = ['https://pypi.org/simple'] index_urls = ['https://pypi.org/simple']
# prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device # prefer nvidia over AMD because AM5/iGPU systems will have a valid ROCm device
if _is_nvidia(): if _is_nvidia():
index_urls += [nvidia_torch_index] index_urls += [nvidia_torch_index]
gpu_accelerated = True
elif _is_amd(): elif _is_amd():
index_urls += [amd_torch_index] index_urls += [amd_torch_index]
gpu_accelerated = True
elif _is_linux_arm64():
index_urls += [cpu_torch_index]
else: else:
index_urls += [cpu_torch_index_nightlies] index_urls += [cpu_torch_index]
if len(index_urls) == 1: if len(index_urls) == 1:
return _dependencies return _dependencies