fixed some issues

Yousef Rafat 2025-12-09 23:54:56 +02:00
parent f030b3afc8
commit d12702ee0b
3 changed files with 44 additions and 24 deletions

View File

@@ -1189,7 +1189,6 @@ class Decoder3D(nn.Module):
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
-        print(f"slicing_up_num: {slicing_up_num}")
         for i, up_block_type in enumerate(up_block_types):
             prev_output_channel = output_channel
             output_channel = reversed_block_out_channels[i]
@@ -1450,6 +1449,7 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
     def encode(self, x: torch.FloatTensor):
         if x.ndim == 4:
             x = x.unsqueeze(2)
+        x = x.to(next(self.parameters()).dtype)
         p = super().encode(x).latent_dist
         z = p.sample().squeeze(2)
         return z, p
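Note: the added cast matters because the wrapper's weights may be loaded in bfloat16 (see working_dtypes below) while callers pass float32 frames; without it, the first convolution would raise a dtype-mismatch error. A minimal standalone sketch of the same pattern (TinyEncoder is a made-up stand-in, not the actual VideoAutoencoderKLWrapper):

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # stand-in for a VAE encoder whose weights live in bfloat16
        self.conv = nn.Conv2d(3, 4, 3, padding=1).to(torch.bfloat16)

    def encode(self, x):
        # same idea as the patch: match the input to the parameter dtype
        x = x.to(next(self.parameters()).dtype)
        return self.conv(x)

enc = TinyEncoder()
print(enc.encode(torch.rand(1, 3, 8, 8)).dtype)  # torch.bfloat16, no dtype-mismatch error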

View File

@@ -268,7 +268,10 @@ class CLIP:
 class VAE:
     def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None):
         if 'decoder.up_blocks.0.resnets.0.norm1.weight' in sd.keys(): #diffusers format
-            sd = diffusers_convert.convert_vae_state_dict(sd)
+            if (metadata is not None and metadata["keep_diffusers_format"] == "true"):
+                pass
+            else:
+                sd = diffusers_convert.convert_vae_state_dict(sd)
         self.memory_used_encode = lambda shape, dtype: (1767 * shape[2] * shape[3]) * model_management.dtype_size(dtype) #These are for AutoencoderKL and need tweaking (should be lower)
         self.memory_used_decode = lambda shape, dtype: (2178 * shape[2] * shape[3] * 64) * model_management.dtype_size(dtype)
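Note: the new branch keeps diffusers-style keys only when the caller opts in through metadata; everything else still goes through convert_vae_state_dict. A minimal sketch of the same gating written as a helper (the helper name is made up, not ComfyUI's API):

def maybe_convert_vae_state_dict(sd, metadata, convert_fn):
    # keep diffusers-format keys only when metadata explicitly asks for it
    if metadata is not None and metadata.get("keep_diffusers_format") == "true":
        return sd
    return convert_fn(sd)

Using .get() here sidesteps a KeyError when the flag is absent, which the patched branch assumes is guaranteed by whoever builds the metadata dict.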
@@ -326,6 +329,15 @@ class VAE:
                 self.first_stage_model = StageC_coder()
                 self.downscale_ratio = 32
                 self.latent_channels = 16
+            elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
+                self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
+                self.memory_used_decode = lambda shape, dtype: (2000 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
+                self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+                self.working_dtypes = [torch.bfloat16, torch.float32]
+                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
+                self.downscale_index_formula = (4, 8, 8)
+                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
+                self.upscale_index_formula = (4, 8, 8)
             elif "decoder.conv_in.weight" in sd:
                 #default SD1.x/SD2.x VAE parameters
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
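Note: the tuple-valued ratios encode a 4x temporal / 8x spatial compression with a causal first frame; the two lambdas map pixel-frame counts to latent-frame counts and back. A small standalone check of those formulas:

import math

down = lambda t: max(0, math.floor((t + 3) / 4))  # pixel frames -> latent frames
up = lambda t: max(0, t * 4 - 3)                  # latent frames -> pixel frames

for frames in (1, 5, 8, 9, 33):
    latents = down(frames)
    print(frames, "->", latents, "->", up(latents))
# 1 -> 1 -> 1, 5 -> 2 -> 5, 9 -> 3 -> 9, 33 -> 9 -> 33: frame counts of the form 1 + 4*k round-trip exactly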
@@ -393,17 +405,6 @@ class VAE:
                 self.downscale_index_formula = (8, 32, 32)
                 self.working_dtypes = [torch.bfloat16, torch.float32]
-            elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
-                self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
-                ddconfig["conv3d"] = True
-                ddconfig["time_compress"] = 4
-                self.memory_used_decode = lambda shape, dtype: (2000 * shape[2] * shape[3] * shape[4] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
-                self.memory_used_encode = lambda shape, dtype: (1000 * max(shape[2], 5) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
-                self.working_dtypes = [torch.bfloat16, torch.float32]
-                self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
-                self.downscale_index_formula = (4, 8, 8)
-                self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
-                self.upscale_index_formula = (4, 8, 8)
             elif "decoder.conv_in.conv.weight" in sd:
                 ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}

View File

@@ -105,27 +105,46 @@ class SeedVR2InputProcessing(io.ComfyNode):
             category="image/video",
             inputs = [
                 io.Image.Input("images"),
-                io.Int.Input("resolution_height"),
-                io.Int.Input("resolution_width")
+                io.Vae.Input("vae"),
+                io.Int.Input("resolution_height", default = 1280, min = 120), # //
+                io.Int.Input("resolution_width", default = 720, min = 120) # just non-zero value
             ],
             outputs = [
-                io.Image.Output("processed_images")
+                io.Latent.Output("vae_conditioning")
             ]
         )

     @classmethod
-    def execute(cls, images, resolution_height, resolution_width):
-        images = images.permute(0, 3, 1, 2)
+    def execute(cls, images, vae, resolution_height, resolution_width):
+        vae = vae.first_stage_model
+        scale = 0.9152; shift = 0
+        if images.dim() != 5: # add the t dim
+            images = images.unsqueeze(0)
+        images = images.permute(0, 1, 4, 2, 3)
+        b, t, c, h, w = images.shape
+        images = images.reshape(b * t, c, h, w)
         max_area = ((resolution_height * resolution_width)** 0.5) ** 2
         clip = Lambda(lambda x: torch.clamp(x, 0.0, 1.0))
         normalize = Normalize(0.5, 0.5)
         images = area_resize(images, max_area)
         images = clip(images)
         images = crop(images, (16, 16))
         images = normalize(images)
-        images = rearrange(images, "t c h w -> c t h w")
+        _, _, new_h, new_w = images.shape
+        images = images.reshape(b, t, c, new_h, new_w)
         images = cut_videos(images)
-        return io.NodeOutput(images)
+        images = rearrange(images, "b t c h w -> b c t h w")
+        latent = vae.encode(images)[0]
+        latent = (latent - shift) * scale
+        return io.NodeOutput({"samples": latent})

 class SeedVR2Conditioning(io.ComfyNode):
     @classmethod
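Note: the rewritten execute flattens the frame batch so the per-frame resize/crop/normalize transforms still see 4D tensors, then restores the time dimension and moves channels in front of it before encoding with the video VAE. A standalone sketch of that layout round-trip with made-up sizes (not ComfyUI code):

import torch
from einops import rearrange

frames = torch.rand(16, 256, 256, 3)        # image batch: t h w c in [0, 1]
x = frames.unsqueeze(0)                     # add the missing batch dim: b t h w c
x = x.permute(0, 1, 4, 2, 3)                # b t c h w
b, t, c, h, w = x.shape
x = x.reshape(b * t, c, h, w)               # flat frame batch for 2D transforms
# ... per-frame resize / crop / normalize would run here ...
x = x.reshape(b, t, c, h, w)                # restore the time dimension
x = rearrange(x, "b t c h w -> b c t h w")  # channels-first video for the VAE encoder
print(x.shape)                              # torch.Size([1, 3, 16, 256, 256])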
@@ -150,8 +169,8 @@ class SeedVR2Conditioning(io.ComfyNode):
         pos_cond = text_positive_conditioning[0][0]
         neg_cond = text_negative_conditioning[0][0]
-        noises = [torch.randn_like(latent) for latent in vae_conditioning]
-        aug_noises = [torch.randn_like(latent) for latent in vae_conditioning]
+        noises = torch.randn_like(vae_conditioning)
+        aug_noises = torch.randn_like(vae_conditioning)
         cond_noise_scale = 0.0
         t = (
@@ -165,8 +184,8 @@ class SeedVR2Conditioning(io.ComfyNode):
         pos_shape = pos_cond.shape[1]
         neg_shape = neg_shape.shape[1]
-        diff = abs(pos_shape.shape[1] - neg_shape.shape[1])
-        if pos_shape.shape[1] > neg_shape.shape[1]:
+        diff = abs(pos_shape - neg_shape)
+        if pos_shape > neg_shape:
             neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
         else:
             pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
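Note: the simplified length-matching works because pos_shape and neg_shape are now plain ints (token counts), and F.pad's pad tuple is ordered from the last dimension backwards, so (0, 0, 0, diff) leaves the embedding dim untouched and appends diff zero rows along the sequence dim. A standalone sketch with made-up shapes:

import torch
import torch.nn.functional as F

pos_cond = torch.rand(1, 77, 4096)   # (batch, tokens, dim)
neg_cond = torch.rand(1, 64, 4096)
diff = abs(pos_cond.shape[1] - neg_cond.shape[1])
if pos_cond.shape[1] > neg_cond.shape[1]:
    neg_cond = F.pad(neg_cond, (0, 0, 0, diff))  # pad the token dim, not the embedding dim
else:
    pos_cond = F.pad(pos_cond, (0, 0, 0, diff))
print(pos_cond.shape, neg_cond.shape)  # both torch.Size([1, 77, 4096])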