diff --git a/comfy/model_base.py b/comfy/model_base.py index bb5ca4da4..60dded117 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -1597,10 +1597,32 @@ class HunyuanVideo15(HunyuanVideo): return out -class HunyuanVideo15_SR_Distilled(HunyuanImage21Refiner): +class HunyuanVideo15_SR_Distilled(HunyuanVideo15): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super().__init__(model_config, model_type, device=device) + def concat_cond(self, **kwargs): + noise = kwargs.get("noise", None) + image = kwargs.get("concat_latent_image", None) + noise_augmentation = kwargs.get("noise_augmentation", 0.0) + device = kwargs["device"] + + if image is None: + image = torch.zeros([noise.shape[0], noise.shape[1] * 2 + 2, noise.shape[-3], noise.shape[-2], noise.shape[-1]], device=comfy.model_management.intermediate_device()) + else: + image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center") + #image = self.process_latent_in(image) # scaling wasn't applied in reference code + image = utils.resize_to_batch_size(image, noise.shape[0]) + if noise_augmentation > 0: + generator = torch.Generator(device="cpu") + generator.manual_seed(kwargs.get("seed", 0) - 10) + lq_image_slice = slice(noise.shape[1] + 1, 2 * noise.shape[1] + 1) + noise = torch.randn(image[:, lq_image_slice].shape, generator=generator, dtype=image.dtype, device="cpu").to(image.device) + image[:, lq_image_slice] = noise_augmentation * noise + min(1.0 - noise_augmentation, 0.75) * image[:, lq_image_slice] + else: + image[:, lq_image_slice] = 0.75 * image[:, lq_image_slice] + return image + def extra_conds(self, **kwargs): out = super().extra_conds(**kwargs) out['disable_time_r'] = comfy.conds.CONDConstant(False)