Merge upstream

commit bf6d91fec0
doctorpangloss, 2024-02-21 13:33:05 -08:00
4 changed files with 40 additions and 8 deletions

View File

@@ -37,11 +37,41 @@ class SDXL(LatentFormat):
 class SD_X4(LatentFormat):
     def __init__(self):
         self.scale_factor = 0.08333
         self.latent_rgb_factors = [
             [-0.2340, -0.3863, -0.3257],
             [ 0.0994, 0.0885, -0.0908],
             [-0.2833, -0.2349, -0.3741],
             [ 0.2523, -0.0055, -0.1651]
         ]
+
+class SC_Prior(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 1.0
+        self.latent_rgb_factors = [
+            [-0.0326, -0.0204, -0.0127],
+            [-0.1592, -0.0427, 0.0216],
+            [ 0.0873, 0.0638, -0.0020],
+            [-0.0602, 0.0442, 0.1304],
+            [ 0.0800, -0.0313, -0.1796],
+            [-0.0810, -0.0638, -0.1581],
+            [ 0.1791, 0.1180, 0.0967],
+            [ 0.0740, 0.1416, 0.0432],
+            [-0.1745, -0.1888, -0.1373],
+            [ 0.2412, 0.1577, 0.0928],
+            [ 0.1908, 0.0998, 0.0682],
+            [ 0.0209, 0.0365, -0.0092],
+            [ 0.0448, -0.0650, -0.1728],
+            [-0.1658, -0.1045, -0.1308],
+            [ 0.0542, 0.1545, 0.1325],
+            [-0.0352, -0.1672, -0.2541]
+        ]
+
+class SC_B(LatentFormat):
+    def __init__(self):
+        self.scale_factor = 1.0
+        self.latent_rgb_factors = [
+            [ 0.1121, 0.2006, 0.1023],
+            [-0.2093, -0.0222, -0.0195],
+            [-0.3087, -0.1535, 0.0366],
+            [ 0.0290, -0.1574, -0.4078]
+        ]
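
Note: latent_rgb_factors of this kind are the per-channel weights a latent previewer can use to project a latent straight to RGB without running the VAE. A minimal sketch of that projection (the helper name and the [-1, 1] to [0, 1] remapping are assumptions for illustration, not code from this commit):

import torch

def latent_to_rgb(latent, latent_rgb_factors):
    # latent: [C, H, W]; factors: [C, 3] -> rgb preview: [3, H, W]
    factors = torch.tensor(latent_rgb_factors, dtype=latent.dtype, device=latent.device)
    rgb = torch.einsum("chw,cr->rhw", latent, factors)
    # map roughly from [-1, 1] into [0, 1] for display
    return ((rgb + 1.0) / 2.0).clamp(0.0, 1.0)

For SC_Prior the latent has 16 channels, so the factor matrix is 16x3; for SC_B and SD_X4 it is 4x3.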

View File

@@ -473,15 +473,11 @@ class StableCascade_B(BaseModel):
         clip_text_pooled = kwargs["pooled_output"]
         if clip_text_pooled is not None:
-            out['clip_text_pooled'] = conds.CONDRegular(clip_text_pooled)
+            out['clip'] = conds.CONDRegular(clip_text_pooled)
         #size of prior doesn't really matter if zeros because it gets resized but I still want it to get batched
         prior = kwargs.get("stable_cascade_prior", torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42), dtype=noise.dtype, layout=noise.layout, device=noise.device))
         out["effnet"] = conds.CONDRegular(prior)
         out["sca"] = conds.CONDRegular(torch.zeros((1,)))
-        cross_attn = kwargs.get("cross_attn", None)
-        if cross_attn is not None:
-            out['clip'] = conds.CONDCrossAttn(cross_attn)
         return out
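
Note: the zero prior is only a correctly batched placeholder; Stage B resizes whatever "effnet" conditioning it receives. A worked example of the shape arithmetic (the 1024x1024 image size and the 4-channel Stage B latent are assumptions for illustration, not values from the diff):

import torch

# For a 1024x1024 image the Stage B noise latent is 1x4x256x256 (4x spatial compression),
# so the placeholder prior is (256 * 4) // 42 = 24 per side:
noise = torch.zeros((1, 4, 256, 256))
prior = torch.zeros((1, 16, (noise.shape[2] * 4) // 42, (noise.shape[3] * 4) // 42),
                    dtype=noise.dtype, layout=noise.layout, device=noise.device)
print(prior.shape)  # torch.Size([1, 16, 24, 24]), matching the 16-channel Stage C latent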

View File

@@ -150,10 +150,10 @@ class StableCascadeSampling(ModelSamplingDiscrete):
         self._init_alpha_cumprod = torch.cos(self.cosine_s / (1 + self.cosine_s) * torch.pi * 0.5) ** 2
         #This part is just for compatibility with some schedulers in the codebase
-        self.num_timesteps = 1000
+        self.num_timesteps = 10000
         sigmas = torch.empty((self.num_timesteps), dtype=torch.float32)
         for x in range(self.num_timesteps):
-            t = x / self.num_timesteps
+            t = (x + 1) / self.num_timesteps
             sigmas[x] = self.sigma(t)
         self.set_sigmas(sigmas)
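
Note: this table exists only so schedulers that expect a discrete sigma list can work with the continuous cosine schedule. A simplified, self-contained sketch of what it computes (the cosine_s value, the sigma formula, and the clamp are assumptions about the surrounding class, not part of the diff):

import torch

cosine_s = torch.tensor(8e-3)  # assumed default
init_alpha_cumprod = torch.cos(cosine_s / (1 + cosine_s) * torch.pi * 0.5) ** 2

def sigma(t):
    # cosine alpha_cumprod, normalized so that alpha_cumprod(0) == 1
    alpha_cumprod = torch.cos((t + cosine_s) / (1 + cosine_s) * torch.pi * 0.5) ** 2 / init_alpha_cumprod
    alpha_cumprod = alpha_cumprod.clamp(0.0001, 0.9999)  # assumed guard against 0 and 1
    return ((1 - alpha_cumprod) / alpha_cumprod) ** 0.5

num_timesteps = 10000
t = (torch.arange(num_timesteps, dtype=torch.float32) + 1) / num_timesteps
sigmas = sigma(t)  # the (x + 1) / N indexing spans (0, 1] instead of [0, 1), avoiding the degenerate t = 0 entry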

View File

@@ -16,6 +16,10 @@ class LCM(comfy.model_sampling.EPS):
         return c_out * x0 + c_skip * model_input
+class X0(comfy.model_sampling.EPS):
+    def calculate_denoised(self, sigma, model_output, model_input):
+        return model_output
+
 class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
     original_timesteps = 50
@@ -67,7 +71,7 @@ class ModelSamplingDiscrete:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "sampling": (["eps", "v_prediction", "lcm"],),
+                              "sampling": (["eps", "v_prediction", "lcm", "x0"],),
                               "zsnr": ("BOOLEAN", {"default": False}),
                               }}
@@ -87,6 +91,8 @@ class ModelSamplingDiscrete:
         elif sampling == "lcm":
             sampling_type = LCM
             sampling_base = ModelSamplingDiscreteDistilled
+        elif sampling == "x0":
+            sampling_type = X0
 
         class ModelSamplingAdvanced(sampling_base, sampling_type):
             pass
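
Note: the new "x0" option makes the model output be treated as the denoised image itself rather than as predicted noise. A rough sketch of the difference (the eps-prediction formula shown is the standard reconstruction and is an assumption here, not quoted from this diff):

def denoised_from_eps(sigma, model_output, model_input):
    # eps prediction: the network predicts the noise, so the denoised image is x - sigma * eps
    return model_input - model_output * sigma

def denoised_from_x0(sigma, model_output, model_input):
    # x0 prediction: the network already predicts the denoised image (what the X0 class above returns)
    return model_output

# The node then builds ModelSamplingAdvanced(sampling_base, sampling_type), mixing the chosen
# prediction type (eps, v_prediction, lcm, or now x0) into the discrete sampling schedule.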