From 413ee3f687a5ec9d287e58a73379a28e1eabb44a Mon Sep 17 00:00:00 2001
From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com>
Date: Wed, 10 Dec 2025 22:58:53 +0200
Subject: [PATCH] seedvr: pass condition latent via extra_conds, fix device handling

NaDiT now reads the condition latent from the "condition" model kwarg
(populated by SeedVR2.extra_conds) instead of digging it out of
transformer_options, flattens it, and concatenates it with the video
tokens along the channel dim. The causal 3D conv pads the temporal head
by repeating the first frame (extend_head), the VAE and the SeedVR2
nodes move their inputs to the model's device/dtype before use, and the
neg_cond shape typo in SeedVR2Conditioning is fixed.
---
 comfy/ldm/seedvr/model.py    | 10 +++++++---
 comfy/ldm/seedvr/vae.py      | 22 ++++++++++++++++++++--
 comfy/model_base.py          |  7 ++++++-
 comfy_extras/nodes_seedvr.py | 16 ++++++++++------
 4 files changed, 43 insertions(+), 12 deletions(-)

diff --git a/comfy/ldm/seedvr/model.py b/comfy/ldm/seedvr/model.py
index 42567fa30..98121f26f 100644
--- a/comfy/ldm/seedvr/model.py
+++ b/comfy/ldm/seedvr/model.py
@@ -1300,8 +1300,7 @@ class NaDiT(nn.Module):
         **kwargs
     ):
         transformer_options = kwargs.get("transformer_options", {})
-        c_or_u_list = transformer_options.get("cond_or_uncond", [])
-        cond_latent = c_or_u_list[0]["condition"]
+        conditions = kwargs.get("condition")  # set by SeedVR2.extra_conds
 
         pos_cond, neg_cond = context.chunk(2, dim=0)
         # txt_shape should be the same for both
@@ -1312,11 +1311,16 @@ class NaDiT(nn.Module):
 
         vid = x
         vid, vid_shape = flatten(x)
+        cond_latent, _ = flatten(conditions)
 
-        vid = torch.cat([cond_latent, vid])
+        vid = torch.cat([cond_latent, vid], dim=-1)  # concat condition along the channel dim
 
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
+        device = next(self.parameters()).device
+        dtype = next(self.parameters()).dtype
+        txt = txt.to(device=device, dtype=dtype)
+        vid = vid.to(device=device, dtype=dtype)
 
         txt = self.txt_in(txt)
         vid, vid_shape = self.vid_in(vid, vid_shape)
diff --git a/comfy/ldm/seedvr/vae.py b/comfy/ldm/seedvr/vae.py
index 4a503dde4..1086f9adc 100644
--- a/comfy/ldm/seedvr/vae.py
+++ b/comfy/ldm/seedvr/vae.py
@@ -330,6 +330,17 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
 
 _receptive_field_t = Literal["half", "full"]
 
+def extend_head(tensor, times: int = 2, memory=None):
+    if memory is not None:  # prepend provided frames instead of repeating the head
+        return torch.cat((memory.to(tensor), tensor), dim=2)
+    assert times >= 0, "extend_head: 'times' must be non-negative"
+    if times == 0:
+        return tensor
+    else:
+        tile_repeat = [1] * tensor.ndim
+        tile_repeat[2] = times  # repeat the first frame 'times' times along T
+        return torch.cat((torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)
+
 class InflatedCausalConv3d(nn.Conv3d):
     def __init__(
         self,
@@ -348,6 +359,7 @@ class InflatedCausalConv3d(nn.Conv3d):
         self,
         input,
     ):
+        input = extend_head(input, times=self.temporal_padding * 2)  # causal temporal padding
         return super().forward(input)
 
     def _load_from_state_dict(
@@ -514,6 +526,8 @@ class Downsample3D(nn.Module):
         self.out_channels = out_channels or channels
         self.temporal_down = temporal_down
         self.spatial_down = spatial_down
+        self.use_conv = use_conv
+        self.padding = padding
 
         self.temporal_ratio = 2 if temporal_down else 1
         self.spatial_ratio = 2 if spatial_down else 1
@@ -630,6 +644,7 @@ class ResnetBlock3D(nn.Module):
             inflation_mode=inflation_mode,
         )
 
+        self.upsample = self.downsample = None
        if self.up:
             self.upsample = Upsample3D(
                 self.in_channels,
@@ -646,6 +661,7 @@ class ResnetBlock3D(nn.Module):
                 inflation_mode=inflation_mode,
             )
 
+        self.conv_shortcut = None
         if self.use_in_shortcut:
             self.conv_shortcut = InflatedCausalConv3d(
                 self.in_channels,
@@ -1093,6 +1109,7 @@ class Encoder3D(nn.Module):
         extra_cond=None,
     ) -> torch.FloatTensor:
         r"""The forward method of the `Encoder` class."""
+        sample = sample.to(next(self.parameters()).device)
         sample = self.conv_in(sample)
 
         if self.training and self.gradient_checkpointing:
@@ -1450,8 +1467,9 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         if x.ndim == 4:
             x = x.unsqueeze(2)
         x = x.to(next(self.parameters()).dtype)
-        p = super().encode(x).latent_dist
-        z = p.sample().squeeze(2)
+        x = x.to(next(self.parameters()).device)
+        p = super().encode(x)
+        z = p.squeeze(2)
         return z, p
 
     def decode(self, z: torch.FloatTensor):
diff --git a/comfy/model_base.py b/comfy/model_base.py
index bbab8627a..f9cc26bfb 100644
--- a/comfy/model_base.py
+++ b/comfy/model_base.py
@@ -798,7 +798,12 @@ class HunyuanDiT(BaseModel):
 class SeedVR2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
-        # TODO: extra_conds could be needed to add
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        condition = kwargs.get("condition", None)
+        if condition is not None:
+            out["condition"] = comfy.conds.CONDRegular(condition)
+        return out
 
 class PixArt(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
diff --git a/comfy_extras/nodes_seedvr.py b/comfy_extras/nodes_seedvr.py
index e83e37c1d..9e8429b66 100644
--- a/comfy_extras/nodes_seedvr.py
+++ b/comfy_extras/nodes_seedvr.py
@@ -39,7 +39,7 @@ def timestep_transform(timesteps, latents_shapes):
         frames > 1,
         vid_shift_fn(heights * widths * frames),
         img_shift_fn(heights * widths),
-    )
+    ).to(timesteps.device)
 
     # Shift timesteps.
     T = 1000.0
@@ -116,6 +116,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
 
     @classmethod
     def execute(cls, images, vae, resolution_height, resolution_width):
+        device = vae.patcher.load_device  # encode on the VAE's load device
         vae = vae.first_stage_model
         scale = 0.9152; shift = 0
 
@@ -140,6 +141,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
             images = cut_videos(images)
 
         images = rearrange(images, "b t c h w -> b c t h w")
+        vae = vae.to(device)
+        images = images.to(device)
         latent = vae.encode(images)[0]
 
         latent = (latent - shift) * scale
@@ -166,24 +169,25 @@ class SeedVR2Conditioning(io.ComfyNode):
     def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
         vae_conditioning = vae_conditioning["samples"]
+        device = vae_conditioning.device
         pos_cond = text_positive_conditioning[0][0]
         neg_cond = text_negative_conditioning[0][0]
 
-        noises = torch.randn_like(vae_conditioning)
-        aug_noises = torch.randn_like(vae_conditioning)
+        noises = torch.randn_like(vae_conditioning).to(device)
+        aug_noises = torch.randn_like(vae_conditioning).to(device)
 
         cond_noise_scale = 0.0
 
         t = (
             torch.tensor([1000.0])
             * cond_noise_scale
-        )
-        shape = torch.tensor(vae_conditioning.shape[1:])[None]
+        ).to(device)
+        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]
         t = timestep_transform(t, shape)
         cond = inter(vae_conditioning, aug_noises, t)
         condition = get_conditions(noises, cond)
 
         pos_shape = pos_cond.shape[1]
-        neg_shape = neg_shape.shape[1]
+        neg_shape = neg_cond.shape[1]
         diff = abs(pos_shape - neg_shape)
         if pos_shape > neg_shape:
             neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
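-- 
A minimal sketch to sanity-check the new extend_head helper, assuming
the patch is applied and comfy.ldm.seedvr.vae is importable:

    import torch
    from comfy.ldm.seedvr.vae import extend_head

    x = torch.arange(4.0).reshape(1, 1, 4, 1, 1)  # (B, C, T, H, W) with T=4

    # times=2 prepends two copies of frame 0, growing T from 4 to 6
    y = extend_head(x, times=2)
    assert y.shape[2] == 6
    assert torch.equal(y[:, :, 0], x[:, :, 0])
    assert torch.equal(y[:, :, 1], x[:, :, 0])

    # supplying "memory" frames takes precedence over 'times'
    mem = torch.full((1, 1, 3, 1, 1), 7.0)
    assert extend_head(x, memory=mem).shape[2] == 7

This mirrors how InflatedCausalConv3d.forward now calls the helper with
times = temporal_padding * 2: the temporal head is padded by repeating
the first frame, so the convolution only sees current and past frames.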