improvements

Yousef Rafat 2025-12-20 23:20:45 +02:00
parent 0d2044a778
commit 4fe772fae9
4 changed files with 23 additions and 8 deletions

@@ -1377,7 +1377,7 @@ class NaDiT(nn.Module):
         out = torch.stack(vid)
         try:
             pos, neg = out.chunk(2)
-            ut = torch.cat([neg, pos])
+            out = torch.cat([neg, pos])
             out = out.movedim(-1, 1)
             return out
         except:

@@ -1541,9 +1541,10 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         x = self.decode(z).sample
         return x, z, p

-    def encode(self, x: torch.FloatTensor):
+    def encode(self, x, orig_dims):
         # we need to keep a reference to the image/video so we can do a colour fix later
         self.original_image_video = x
+        self.img_dims = orig_dims
         if x.ndim == 4:
             x = x.unsqueeze(2)
         x = x.to(next(self.parameters()).dtype)
@@ -1570,6 +1571,8 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         input = rearrange(self.original_image_video[0], "c t h w -> t c h w")
         x = wavelet_reconstruction(x, input)
+        o_h, o_w = self.img_dims
+        x = x[..., :o_h, :o_w]
         return x

     def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
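
Taken together, the two hunks above establish a size contract between encode and decode: encode remembers the pre-padding height and width in self.img_dims, and decode crops the reconstruction back to that size after the colour fix. A minimal sketch of that round trip, using a hypothetical DimsRoundTrip stand-in rather than the real VideoAutoencoderKLWrapper (the padding scheme mirrors div_pad from the node changes below):

import torch
import torch.nn.functional as F

class DimsRoundTrip:
    # Stand-in for the wrapper's new contract; not the actual model code.
    def encode(self, x, orig_dims):
        self.img_dims = orig_dims              # (o_h, o_w) before padding
        return x                               # pretend this is the latent

    def decode(self, x):
        o_h, o_w = self.img_dims
        return x[..., :o_h, :o_w]              # drop the padded rows/columns

frames = torch.rand(1, 3, 4, 500, 700)         # b c t h w, not divisible by 16
o_h, o_w = frames.shape[-2:]
pad_h = (16 - o_h % 16) % 16
pad_w = (16 - o_w % 16) % 16
padded = F.pad(frames, (0, pad_w, 0, pad_h))   # pad bottom/right, as div_pad does

m = DimsRoundTrip()
out = m.decode(m.encode(padded, (o_h, o_w)))
assert out.shape[-2:] == (o_h, o_w)            # original size restored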

@@ -386,6 +386,8 @@ class VAE:
             self.downscale_index_formula = (4, 8, 8)
             self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
             self.upscale_index_formula = (4, 8, 8)
+            self.process_input = lambda image: image
+            self.crop_input = False
         elif "decoder.conv_in.weight" in sd:
             if sd['decoder.conv_in.weight'].shape[1] == 64:
                 ddconfig = {"block_out_channels": [128, 256, 512, 512, 1024, 1024], "in_channels": 3, "out_channels": 3, "num_res_blocks": 2, "ffactor_spatial": 32, "downsample_match_channel": True, "upsample_match_channel": True}

@@ -68,14 +68,21 @@ def area_resize(image, max_area):
         interpolation=InterpolationMode.BICUBIC,
     )

-def crop(image, factor):
+def div_pad(image, factor):
     height_factor, width_factor = factor
     height, width = image.shape[-2:]
-    cropped_height = height - (height % height_factor)
-    cropped_width = width - (width % width_factor)
+    pad_height = (height_factor - (height % height_factor)) % height_factor
+    pad_width = (width_factor - (width % width_factor)) % width_factor
+    if pad_height == 0 and pad_width == 0:
+        return image
+    if isinstance(image, torch.Tensor):
+        padding = (0, pad_width, 0, pad_height)
+        image = torch.nn.functional.pad(image, padding, mode='constant', value=0.0)
-    image = TVF.center_crop(img=image, output_size=(cropped_height, cropped_width))
     return image

 def cut_videos(videos):
@@ -120,6 +127,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
         device = vae.patcher.load_device
         offload_device = comfy.model_management.intermediate_device()
+        main_device = comfy.model_management.get_torch_device()
+        images = images.to(main_device)
         vae_model = vae.first_stage_model
         scale = 0.9152; shift = 0
         if images.dim() != 5: # add the t dim
@@ -135,7 +144,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = area_resize(images, max_area)
         images = clip(images)
-        images = crop(images, (16, 16))
+        o_h, o_w = images.shape[-2:]
+        images = div_pad(images, (16, 16))
         images = normalize(images)
         _, _, new_h, new_w = images.shape
@@ -145,7 +155,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = rearrange(images, "b t c h w -> b c t h w")
         images = images.to(device)
         vae_model = vae_model.to(device)
-        latent = vae_model.encode(images)[0]
+        latent = vae_model.encode(images, [o_h, o_w])[0]
         vae_model = vae_model.to(offload_device)
         latent = latent.unsqueeze(2) if latent.ndim == 4 else latent
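
For reference, the padding arithmetic above is a no-op when a dimension is already a multiple of the factor, since (f - h % f) % f is then 0; otherwise it pads only the bottom and right edges, which is what lets decode recover the original frame by slicing. A small self-contained sketch of the div_pad logic and of how the node now calls the encoder (only the tensor branch is kept; any non-tensor handling and the commented encode call are assumptions, not shown in the hunks):

import torch

def div_pad(image, factor):
    # Pad the last two dims up to the next multiple of (height_factor, width_factor).
    height_factor, width_factor = factor
    height, width = image.shape[-2:]
    pad_height = (height_factor - (height % height_factor)) % height_factor
    pad_width = (width_factor - (width % width_factor)) % width_factor
    if pad_height == 0 and pad_width == 0:
        return image
    padding = (0, pad_width, 0, pad_height)    # right and bottom only
    return torch.nn.functional.pad(image, padding, mode='constant', value=0.0)

# Usage mirroring SeedVR2InputProcessing: remember the true size first,
# pad to a multiple of 16, then pass both the tensor and the original dims on.
images = torch.rand(1, 4, 3, 718, 1277)        # b t c h w
o_h, o_w = images.shape[-2:]
images = div_pad(images, (16, 16))
assert images.shape[-2] % 16 == 0 and images.shape[-1] % 16 == 0
# latent = vae_model.encode(images, [o_h, o_w])[0]   # as in the commit (requires the model)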