Fix #5 TAESD node was using a bad variable name that shadowed a module in a relative import

This commit is contained in:
doctorpangloss 2024-03-29 16:28:13 -07:00
parent bd87697fdf
commit b0ab12bf05

View File

@ -666,7 +666,7 @@ class VAELoader:
@staticmethod
def load_taesd(name):
sd = {}
sd_ = {}
approx_vaes = folder_paths.get_filename_list("vae_approx")
encoder = next(filter(lambda a: a.startswith("{}_encoder.".format(name)), approx_vaes))
@ -674,17 +674,17 @@ class VAELoader:
enc = utils.load_torch_file(folder_paths.get_full_path("vae_approx", encoder))
for k in enc:
sd["taesd_encoder.{}".format(k)] = enc[k]
sd_["taesd_encoder.{}".format(k)] = enc[k]
dec = utils.load_torch_file(folder_paths.get_full_path("vae_approx", decoder))
for k in dec:
sd["taesd_decoder.{}".format(k)] = dec[k]
sd_["taesd_decoder.{}".format(k)] = dec[k]
if name == "taesd":
sd["vae_scale"] = torch.tensor(0.18215)
sd_["vae_scale"] = torch.tensor(0.18215)
elif name == "taesdxl":
sd["vae_scale"] = torch.tensor(0.13025)
return sd
sd_["vae_scale"] = torch.tensor(0.13025)
return sd_
@classmethod
def INPUT_TYPES(s):
@ -697,11 +697,11 @@ class VAELoader:
#TODO: scale factor?
def load_vae(self, vae_name):
if vae_name in ["taesd", "taesdxl"]:
sd = self.load_taesd(vae_name)
sd_ = self.load_taesd(vae_name)
else:
vae_path = folder_paths.get_full_path("vae", vae_name)
sd = utils.load_torch_file(vae_path)
vae = sd.VAE(sd=sd)
sd_ = utils.load_torch_file(vae_path)
vae = sd.VAE(sd=sd_)
return (vae,)
class ControlNetLoader: