neg condition was passed as context

This commit is contained in:
Yousef R. Gamaleldin 2026-01-16 22:17:33 +02:00
parent bc4fd2cd11
commit c78bfda132
3 changed files with 49 additions and 46 deletions

View File

@ -1353,7 +1353,8 @@ class NaDiT(nn.Module):
pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0) pos_cond, neg_cond = pos_cond.squeeze(0), neg_cond.squeeze(0)
txt, txt_shape = flatten([pos_cond, neg_cond]) txt, txt_shape = flatten([pos_cond, neg_cond])
except: except:
txt, txt_shape = flatten([context.squeeze(0)]) context = self.positive_conditioning
txt, txt_shape = flatten([context])
vid, vid_shape = flatten(x) vid, vid_shape = flatten(x)
cond_latent, _ = flatten(conditions) cond_latent, _ = flatten(conditions)

View File

@ -2140,10 +2140,11 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
input = input.to(x.device) input = input.to(x.device)
o_h, o_w = self.img_dims o_h, o_w = self.img_dims
x = x[..., :o_h, :o_w]
input = input[..., :o_h, :o_w ]
x = lab_color_transfer(x, input) x = lab_color_transfer(x, input)
x = x.unsqueeze(0) x = x.unsqueeze(0)
x = x[..., :o_h, :o_w]
x = rearrange(x, "b t c h w -> b c t h w") x = rearrange(x, "b t c h w -> b c t h w")
# ensure even dims for save video # ensure even dims for save video
@ -2154,8 +2155,12 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
return x return x
def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]): def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float], memory_device = "same"):
set_norm_limit(norm_max_mem) set_norm_limit(norm_max_mem)
for m in self.modules(): for m in self.modules():
if isinstance(m, InflatedCausalConv3d): if isinstance(m, InflatedCausalConv3d):
m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf")) m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))
for module in self.modules():
if isinstance(module, InflatedCausalConv3d):
module.set_memory_device(memory_device)

View File

@ -1307,9 +1307,6 @@ class SeedVR2(supported_models_base.BASE):
unet_config = { unet_config = {
"image_model": "seedvr2" "image_model": "seedvr2"
} }
sampling_settings = {
"shift": 1.0,
}
latent_format = comfy.latent_formats.SeedVR2 latent_format = comfy.latent_formats.SeedVR2
vae_key_prefix = ["vae."] vae_key_prefix = ["vae."]