commit 413ee3f687
parent d12702ee0b
@@ -1300,8 +1300,7 @@ class NaDiT(nn.Module):
         **kwargs
     ):
         transformer_options = kwargs.get("transformer_options", {})
-        c_or_u_list = transformer_options.get("cond_or_uncond", [])
-        cond_latent = c_or_u_list[0]["condition"]
+        conditions = kwargs.get("condition")
 
         pos_cond, neg_cond = context.chunk(2, dim=0)
         # txt_shape should be the same for both
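For reference, the `context.chunk(2, dim=0)` call in the context lines assumes the positive and negative prompt embeddings arrive batched along dim 0; a minimal sketch with hypothetical shapes:

import torch

# Hypothetical shapes: batch of 2 (cond + uncond), 77 tokens, 1024 channels.
context = torch.randn(2, 77, 1024)
pos_cond, neg_cond = context.chunk(2, dim=0)
assert pos_cond.shape == neg_cond.shape == (1, 77, 1024)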
@@ -1312,11 +1311,16 @@ class NaDiT(nn.Module):
 
         vid = x
         vid, vid_shape = flatten(x)
+        cond_latent, _ = flatten(conditions)
 
-        vid = torch.cat([cond_latent, vid])
+        vid = torch.cat([cond_latent, vid], dim=-1)
         if txt_shape.size(-1) == 1 and self.need_txt_repeat:
             txt, txt_shape = repeat(txt, txt_shape, "l c -> t l c", t=vid_shape[:, 0])
 
+        device = next(self.parameters()).device
+        dtype = next(self.parameters()).dtype
+        txt = txt.to(device).to(dtype)
+        vid = vid.to(device).to(dtype)
         txt = self.txt_in(txt)
 
         vid, vid_shape = self.vid_in(vid, vid_shape)
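The added device/dtype moves use the common `next(self.parameters())` idiom for discovering where a module lives; a small self-contained illustration:

import torch
from torch import nn

# A module's device and dtype can be read off any of its parameters.
module = nn.Linear(8, 8).to(torch.bfloat16)
param = next(module.parameters())
x = torch.randn(2, 8).to(param.device).to(param.dtype)
print(module(x).dtype)  # torch.bfloat16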
@@ -330,6 +330,17 @@ def safe_interpolate_operation(x, size=None, scale_factor=None, mode='nearest',
 
 _receptive_field_t = Literal["half", "full"]
 
+def extend_head(tensor, times: int = 2, memory = None):
+    if memory is not None:
+        return torch.cat((memory.to(tensor), tensor), dim=2)
+    assert times >= 0, "Invalid input for function 'extend_head'!"
+    if times == 0:
+        return tensor
+    else:
+        tile_repeat = [1] * tensor.ndim
+        tile_repeat[2] = times
+        return torch.cat(tensors=(torch.tile(tensor[:, :, :1], tile_repeat), tensor), dim=2)
+
 class InflatedCausalConv3d(nn.Conv3d):
     def __init__(
         self,
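A quick illustration of what the new `extend_head` does, assuming a (B, C, T, H, W) layout where dim 2 is time: it left-pads the clip by repeating its first frame, or by prepending cached `memory` frames when given. Assuming the definition added above is in scope:

import torch

x = torch.arange(4.0).reshape(1, 1, 4, 1, 1)  # frames 0, 1, 2, 3
padded = extend_head(x, times=2)
print(padded.flatten().tolist())              # [0.0, 0.0, 0.0, 1.0, 2.0, 3.0]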
@@ -348,6 +359,7 @@ class InflatedCausalConv3d(nn.Conv3d):
         self,
         input,
     ):
+        input = extend_head(input, times=self.temporal_padding * 2)
         return super().forward(input)
 
     def _load_from_state_dict(
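A note on the padding arithmetic (my reading, assuming odd temporal kernels): replicating `temporal_padding * 2` head frames supplies exactly the `k - 1` extra frames a size-`k` kernel consumes, so the convolution only ever sees current and past frames while the output keeps its temporal length.

# For an odd kernel size k with temporal_padding = (k - 1) // 2:
k = 3
temporal_padding = (k - 1) // 2
assert temporal_padding * 2 == k - 1  # all padding moved to the head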
@@ -514,6 +526,8 @@ class Downsample3D(nn.Module):
         self.out_channels = out_channels or channels
         self.temporal_down = temporal_down
         self.spatial_down = spatial_down
+        self.use_conv = use_conv
+        self.padding = padding
 
         self.temporal_ratio = 2 if temporal_down else 1
         self.spatial_ratio = 2 if spatial_down else 1
@@ -630,6 +644,7 @@ class ResnetBlock3D(nn.Module):
             inflation_mode=inflation_mode,
         )
 
+        self.upsample = self.downsample = None
         if self.up:
             self.upsample = Upsample3D(
                 self.in_channels,
@@ -646,6 +661,7 @@ class ResnetBlock3D(nn.Module):
             inflation_mode=inflation_mode,
         )
 
+        self.conv_shortcut = None
         if self.use_in_shortcut:
             self.conv_shortcut = InflatedCausalConv3d(
                 self.in_channels,
@@ -1093,6 +1109,7 @@ class Encoder3D(nn.Module):
         extra_cond=None,
     ) -> torch.FloatTensor:
         r"""The forward method of the `Encoder` class."""
+        sample = sample.to(next(self.parameters()).device)
         sample = self.conv_in(sample)
         if self.training and self.gradient_checkpointing:
 
@@ -1450,8 +1467,9 @@ class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
         if x.ndim == 4:
             x = x.unsqueeze(2)
         x = x.to(next(self.parameters()).dtype)
-        p = super().encode(x).latent_dist
-        z = p.sample().squeeze(2)
+        x = x.to(next(self.parameters()).device)
+        p = super().encode(x)
+        z = p.squeeze(2)
         return z, p
 
     def decode(self, z: torch.FloatTensor):
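The unsqueeze/squeeze pair around `encode` treats a single image as a one-frame video; a shape-only sketch (the latent tensor here is a stand-in for the encoder output):

import torch

x = torch.randn(1, 3, 64, 64)    # (B, C, H, W) image
x = x.unsqueeze(2)               # (B, C, 1, H, W) one-frame video
z = torch.randn(1, 16, 1, 8, 8)  # stand-in for the encoded latent
print(z.squeeze(2).shape)        # torch.Size([1, 16, 8, 8])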
@@ -798,7 +798,11 @@ class HunyuanDiT(BaseModel):
 class SeedVR2(BaseModel):
     def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
         super().__init__(model_config, model_type, device, comfy.ldm.seedvr.model.NaDiT)
-    # TODO: extra_conds could be needed to add
+    def extra_conds(self, **kwargs):
+        out = super().extra_conds(**kwargs)
+        condition = kwargs.get("condition", None)
+        out["condition"] = comfy.conds.CONDRegular(condition)
+        return out
 
 class PixArt(BaseModel):
     def __init__(self, model_config, model_type=ModelType.EPS, device=None):
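Roughly how the new `extra_conds` hook connects to the `kwargs.get("condition")` read in `NaDiT.forward` (a simplified, self-contained sketch with hypothetical names, not ComfyUI's actual sampling plumbing): the dict returned by `extra_conds` travels with the conditioning, and each entry is unwrapped and handed to the model's forward as a same-named keyword argument.

class CONDRegularSketch:  # stand-in for comfy.conds.CONDRegular
    def __init__(self, cond):
        self.cond = cond

def extra_conds(**kwargs):
    return {"condition": CONDRegularSketch(kwargs.get("condition", None))}

def forward(**kwargs):
    conditions = kwargs.get("condition")  # mirrors the NaDiT.forward change
    return conditions

out = extra_conds(condition="latent-placeholder")
print(forward(**{k: v.cond for k, v in out.items()}))  # latent-placeholder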
@@ -39,7 +39,7 @@ def timestep_transform(timesteps, latents_shapes):
         frames > 1,
         vid_shift_fn(heights * widths * frames),
         img_shift_fn(heights * widths),
-    )
+    ).to(timesteps.device)
 
     # Shift timesteps.
     T = 1000.0
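The added `.to(timesteps.device)` matters because `torch.where` returns a tensor on its inputs' device, which in this function comes from the latent-shape tensors rather than necessarily from `timesteps`; a minimal illustration of the idiom (CPU stands in for the mismatched device):

import torch

frames = torch.tensor([4.0])                    # shape metadata, possibly CPU
shift = torch.where(frames > 1, frames * 2.0, frames)
timesteps = torch.tensor([500.0])               # on the sampling device
shift = shift.to(timesteps.device)              # keep later arithmetic on one device
print(timesteps * shift)                        # tensor([4000.])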
@@ -116,6 +116,7 @@ class SeedVR2InputProcessing(io.ComfyNode):
 
     @classmethod
     def execute(cls, images, vae, resolution_height, resolution_width):
+        device = vae.patcher.load_device
         vae = vae.first_stage_model
         scale = 0.9152; shift = 0
 
@@ -140,6 +141,8 @@ class SeedVR2InputProcessing(io.ComfyNode):
         images = cut_videos(images)
 
         images = rearrange(images, "b t c h w -> b c t h w")
+        vae = vae.to(device)
+        images = images.to(device)
         latent = vae.encode(images)[0]
 
         latent = (latent - shift) * scale
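The `(latent - shift) * scale` line is an affine latent normalization; presumably the inverse (`latent / scale + shift`) is applied before decoding, though that code sits outside this hunk. A sketch with the node's constants:

scale = 0.9152
shift = 0

def normalize(latent):
    return (latent - shift) * scale

def denormalize(latent):  # assumed inverse, applied before decoding
    return latent / scale + shift

assert abs(denormalize(normalize(1.0)) - 1.0) < 1e-9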
@@ -166,24 +169,25 @@ class SeedVR2Conditioning(io.ComfyNode):
     def execute(cls, text_positive_conditioning, text_negative_conditioning, vae_conditioning) -> io.NodeOutput:
 
         vae_conditioning = vae_conditioning["samples"]
+        device = vae_conditioning.device
         pos_cond = text_positive_conditioning[0][0]
         neg_cond = text_negative_conditioning[0][0]
 
-        noises = torch.randn_like(vae_conditioning)
-        aug_noises = torch.randn_like(vae_conditioning)
+        noises = torch.randn_like(vae_conditioning).to(device)
+        aug_noises = torch.randn_like(vae_conditioning).to(device)
 
         cond_noise_scale = 0.0
         t = (
             torch.tensor([1000.0])
             * cond_noise_scale
-        )
-        shape = torch.tensor(vae_conditioning.shape[1:])[None]
+        ).to(device)
+        shape = torch.tensor(vae_conditioning.shape[1:]).to(device)[None]
         t = timestep_transform(t, shape)
         cond = inter(vae_conditioning, aug_noises, t)
         condition = get_conditions(noises, cond)
 
         pos_shape = pos_cond.shape[1]
-        neg_shape = neg_shape.shape[1]
+        neg_shape = neg_cond.shape[1]
         diff = abs(pos_shape - neg_shape)
         if pos_shape > neg_shape:
             neg_cond = F.pad(neg_cond, (0, 0, 0, diff))
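`F.pad`'s pad tuple is read from the last dimension backwards, so `(0, 0, 0, diff)` leaves the channel axis alone and appends `diff` zero rows along the token axis, bringing the shorter prompt embedding up to the longer one's length; for example:

import torch
import torch.nn.functional as F

neg_cond = torch.randn(1, 70, 1024)          # hypothetical shorter embedding
diff = 7
neg_cond = F.pad(neg_cond, (0, 0, 0, diff))  # pad dim 1 at the end
print(neg_cond.shape)                        # torch.Size([1, 77, 1024])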