mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-23 13:00:54 +08:00
.
This commit is contained in:
parent
d12702ee0b
commit
413ee3f687
@ -1300,8 +1300,7 @@ class NaDiT(nn.Module):
|
||||
**kwargs
|
||||
):
|
||||
transformer_options = kwargs.get("transformer_options", {})
|
||||
c_or_u_list = transformer_options.get("cond_or_uncond", [])
|
||||
cond_latent = c_or_u_list[0]["condition"]
|
||||
conditions = kwargs.get("condition")
|
||||
|
||||
pos_cond, neg_cond = context.chunk(2, dim=0)
|
||||
# txt_shape should be the same for both
|
||||
@ -1312,11 +1311,16 @@ class NaDiT(nn.Module):
|
||||
|
||||
vid = x
|
||||
vid, vid_shape = flatten(x)
|
||||
cond_latent, _ = flatten(conditions)
|
||||
|
||||
vid = torch.cat([cond_latent, vid])
|
||||
vid = torch.cat([cond_latent, vid], dim=-1)
|
||||
if txt_shape.size(-1) == 1 and self.need_txt_repeat:
|
||||
txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
|
||||
|
||||
device = next(self.parameters()).device
|
||||
dtype = next(self.parameters()).dtype
|
||||
txt = txt.to(device).to(dtype)
|
||||
vid = vid.to(device).to(dtype)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vid, vid_shape = self.vid_in(vid, vid_shape)
|
||||
|
||||
@ -330,6 +330,17 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
|
||||
|
||||
_receptive_field_t = Literal["half", "full"]
|
||||
|
||||
def extend_head(tensor, times: int = 2, memory=None):
    """Prepend extra slices to ``tensor`` along dimension 2.

    Two modes:

    * ``memory`` given: ``memory`` is converted with ``memory.to(tensor)``
      (so it takes on ``tensor``'s dtype/device) and concatenated in front of
      ``tensor``. ``times`` is ignored in this mode.
    * ``memory`` is None: the first slice along dim 2 (``tensor[:, :, :1]``)
      is repeated ``times`` times and concatenated in front of ``tensor``.

    Args:
        tensor: Input tensor; must have at least 3 dimensions because dim 2
            is indexed. (Presumably dim 2 is the temporal axis of a
            ``b c t h w`` video tensor -- confirm against callers.)
        times: Number of copies of the first slice to prepend. Must be >= 0.
        memory: Optional tensor to prepend instead of the repeated head.

    Returns:
        ``tensor`` itself when ``times == 0`` and no memory is given,
        otherwise a new tensor that is longer along dim 2.

    Raises:
        AssertionError: If ``memory`` is None and ``times`` is negative.
    """
    if memory is not None:
        return torch.cat((memory.to(tensor), tensor), dim=2)

    if times < 0:
        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped when Python runs with -O; exception type and message are
        # kept for backward compatibility.
        raise AssertionError("Invalid input for function 'extend_head'!")

    if times == 0:
        return tensor

    # Repeat only along dim 2; every other dimension keeps a factor of 1.
    tile_repeat = [1] * tensor.ndim
    tile_repeat[2] = times
    head = torch.tile(tensor[:, :, :1], tile_repeat)
    return torch.cat((head, tensor), dim=2)
|
||||
|
||||
class InflatedCausalConv3d(nn.Conv3d):
|
||||
def __init__(
|
||||
self,
|
||||
@ -348,6 +359,7 @@ class InflatedCausalConv3d(nn.Conv3d):
|
||||
self,
|
||||
input,
|
||||
):
|
||||
input = extend_head(input, times=self.temporal_padding * 2)
|
||||
return super().forward(input)
|
||||
|
||||
def _load_from_state_dict(
|
||||
@ -514,6 +526,8 @@ class Downsample3D(nn.Module):
|
||||
self.out_channels = out_channels or channels
|
||||
self.temporal_down = temporal_down
|
||||
self.spatial_down = spatial_down
|
||||
self.use_conv = use_conv
|
||||
self.padding = padding
|
||||
|
||||
self.temporal_ratio = 2 if temporal_down else 1
|
||||
self.spatial_ratio = 2 if spatial_down else 1
|
||||
@ -630,6 +644,7 @@ class ResnetBlock3D(nn.Module):
|
||||
inflation_mode=inflation_mode,
|
||||
)
|
||||
|
||||
self.upsample = self.downsample = None
|
||||
if self.up:
|
||||
self.upsample = Upsample3D(
|
||||
self.in_channels,
|
||||
@ -646,6 +661,7 @@ class ResnetBlock3D(nn.Module):
|
||||
inflation_mode=inflation_mode,
|
||||
)
|
||||
|
||||
self.conv_shortcut = None
|
||||
if self.use_in_shortcut:
|
||||
self.conv_shortcut = InflatedCausalConv3d(
|
||||
self.in_channels,
|
||||
@ -1093,6 +1109,7 @@ class Encoder3D(nn.Module):
|
||||
extra_cond=None,
|
||||
) -> torch.FloatTensor:
|
||||
r"""The forward method of the `Encoder` class."""
|
||||
sample = sample.to(next(self.parameters()).device)
|
||||
sample = self.conv_in(sample)
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
@ -1450,8 +1467,9 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
|
||||
if x.ndim == 4:
|
||||
x = x.unsqueeze(2)
|
||||
x = x.to(next(self.parameters()).dtype)
|
||||
p = super().encode(x).latent_dist
|
||||
z = p.sample().squeeze(2)
|
||||
x = x.to(next(self.parameters()).device)
|
||||
p = super().encode(x)
|
||||
z = p.squeeze(2)
|
||||
return z, p
|
||||
|
||||
def decode(self, z: torch.FloatTensor):
|
||||
|
||||
@ -798,7 +798,11 @@ class HunyuanDiT(BaseModel):
|
||||
class SeedVR2(BaseModel):
    """Model wrapper that builds a ``BaseModel`` around the SeedVR ``NaDiT`` network."""

    def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
        """Initialise the base model with ``NaDiT`` as the underlying network class.

        Args:
            model_config: Forwarded unchanged to ``BaseModel.__init__``.
            model_type: Forwarded to ``BaseModel.__init__``; defaults to
                ``ModelType.FLOW``.
            device: Forwarded unchanged to ``BaseModel.__init__``.
        """
        super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)

    def extra_conds(self, **kwargs):
        """Extend the base extra conditioning with the ``"condition"`` entry.

        Reads ``condition`` from ``kwargs`` and, when present, stores it under
        ``out["condition"]`` wrapped in ``comfy.conds.CONDRegular``.

        Returns:
            The dict returned by ``super().extra_conds(**kwargs)``, with
            ``"condition"`` added only when a condition was supplied.
        """
        out = super().extra_conds(**kwargs)
        condition = kwargs.get("condition", None)
        # Skip the entry when no condition was passed, so None is never
        # wrapped in CONDRegular.
        # NOTE(review): presumably CONDRegular expects a tensor -- confirm.
        if condition is not None:
            out["condition"] = comfy.conds.CONDRegular(condition)
        return out
|
||||
|
||||
class PixArt(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.EPS, device=None):
|
||||
|
||||
@ -39,7 +39,7 @@ def timestep_transform(timesteps, latents_shapes):
|
||||
frames > 1,
|
||||
vid_shift_fn(heights * widths * frames),
|
||||
img_shift_fn(heights * widths),
|
||||
)
|
||||
).to(timesteps.device)
|
||||
|
||||
# Shift timesteps.
|
||||
T = 1000.0
|
||||
@ -116,6 +116,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images, vae, resolution_height, resolution_width):
|
||||
device = vae.patcher.load_device
|
||||
vae = vae.first_stage_model
|
||||
scale = 0.9152; shift = 0
|
||||
|
||||
@ -140,6 +141,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
|
||||
images = cut_videos(images)
|
||||
|
||||
images = rearrange(images, "b t c h w -> b c t h w")
|
||||
vae = vae.to(device)
|
||||
images = images.to(device)
|
||||
latent = vae.encode(images)[0]
|
||||
|
||||
latent = (latent - shift) * scale
|
||||
@ -166,24 +169,25 @@ class SeedVR2Conditioning(io.ComfyNode):
|
||||
def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
|
||||
|
||||
vae_conditioning = vae_conditioning["samples"]
|
||||
device = vae_conditioning.device
|
||||
pos_cond = text_positive_conditioning[0][0]
|
||||
neg_cond = text_negative_conditioning[0][0]
|
||||
|
||||
noises = torch.randn_like(vae_conditioning)
|
||||
aug_noises = torch.randn_like(vae_conditioning)
|
||||
noises = torch.randn_like(vae_conditioning).to(device)
|
||||
aug_noises = torch.randn_like(vae_conditioning).to(device)
|
||||
|
||||
cond_noise_scale = 0.0
|
||||
t = (
|
||||
torch.tensor([1000.0])
|
||||
* cond_noise_scale
|
||||
)
|
||||
shape = torch.tensor(vae_conditioning.shape[1:])[None]
|
||||
).to(device)
|
||||
shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]
|
||||
t = timestep_transform(t, shape)
|
||||
cond = inter(vae_conditioning, aug_noises, t)
|
||||
condition = get_conditions(noises, cond)
|
||||
|
||||
pos_shape = pos_cond.shape[1]
|
||||
neg_shape = neg_shape.shape[1]
|
||||
neg_shape = neg_cond.shape[1]
|
||||
diff = abs(pos_shape - neg_shape)
|
||||
if pos_shape > neg_shape:
|
||||
neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
|
||||
|
||||
Loading…
Reference in New Issue
Block a user