SR model fixes

This also still needs timesteps scheduling based on the noise scale, can be used with two samplers too already
This commit is contained in:
kijai 2025-11-19 01:13:11 +02:00 committed by comfyanonymous
parent 8a45f558a9
commit 0aa6eb2edc

View File

@ -1597,10 +1597,32 @@ class HunyuanVideo15(HunyuanVideo):
return out return out
class HunyuanVideo15_SR_Distilled(HunyuanImage21Refiner): class HunyuanVideo15_SR_Distilled(HunyuanVideo15):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device) super().__init__(model_config, model_type, device=device)
def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
image = kwargs.get("concat_latent_image", None)
noise_augmentation = kwargs.get("noise_augmentation", 0.0)
device = kwargs["device"]
if image is None:
image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device())
else:
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
#image = self.process_latent_in(image) # scaling wasn't applied in reference code
image = utils.resize_to_batch_size(image, noise.shape[0])
if noise_augmentation > 0:
generator = torch.Generator(device="cpu")
generator.manual_seed(kwargs.get("seed", 0) - 10)
lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1)
noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device)
image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice]
else:
image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice]
return image
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs) out = super().extra_conds(**kwargs)
out['disable_time_r'] = comfy.conds.CONDConstant(False) out['disable_time_r'] = comfy.conds.CONDConstant(False)