mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-24 21:30:15 +08:00)
Fix linting issues

This commit is contained in:
parent 66cf9b41f2
commit dd9a781654
@@ -3,8 +3,9 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torchaudio
 from typing import Optional
-from comfy.ldm.modules.attention import optimized_attention_masked
-import comfy.ops
+from ..ldm.modules.attention import optimized_attention_masked
+from .. import ops
+

 class WhisperFeatureExtractor(nn.Module):
     def __init__(self, n_mels=128, device=None):
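The import change in this hunk is the pattern repeated throughout the commit: absolute `comfy.*` imports become package-relative ones. A minimal sketch of the difference; these lines are meant to be read as the top of a module inside the `comfy` package, since relative imports do not work in standalone scripts:

```python
# Before: resolves only when the code is importable under the top-level
# name "comfy".
import comfy.ops
linear_cls = comfy.ops.disable_weight_init.Linear

# After: resolves relative to the containing package, so the same module
# also works when the package is vendored or embedded under another name.
from .. import ops
linear_cls = ops.disable_weight_init.Linear
```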
@@ -66,11 +67,11 @@ class MultiHeadAttention(nn.Module):
         self.out_proj = operations.Linear(d_model, d_model, dtype=dtype, device=device)

     def forward(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         batch_size, seq_len, _ = query.shape

@@ -96,9 +97,9 @@ class EncoderLayer(nn.Module):
         self.final_layer_norm = operations.LayerNorm(d_model, dtype=dtype, device=device)

     def forward(
         self,
         x: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
         residual = x
         x = self.self_attn_layer_norm(x)
@@ -117,15 +118,15 @@ class EncoderLayer(nn.Module):

 class AudioEncoder(nn.Module):
     def __init__(
         self,
         n_mels: int = 128,
         n_ctx: int = 1500,
         n_state: int = 1280,
         n_head: int = 20,
         n_layer: int = 32,
         dtype=None,
         device=None,
         operations=None
     ):
         super().__init__()

@@ -147,7 +148,7 @@ class AudioEncoder(nn.Module):

         x = x.transpose(1, 2)

-        x = x + comfy.ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)
+        x = x + ops.cast_to_input(self.embed_positions.weight[:, :x.shape[1]], x)

         all_x = ()
         for layer in self.layers:
@@ -161,15 +162,15 @@ class AudioEncoder(nn.Module):

 class WhisperLargeV3(nn.Module):
     def __init__(
         self,
         n_mels: int = 128,
         n_audio_ctx: int = 1500,
         n_audio_state: int = 1280,
         n_audio_head: int = 20,
         n_audio_layer: int = 32,
         dtype=None,
         device=None,
         operations=None
     ):
         super().__init__()

@@ -1,7 +1,7 @@
 # Credits:
 # Original Flux code can be found on: https://github.com/black-forest-labs/flux
 # Chroma Radiance adaption referenced from https://github.com/lodestone-rock/flow
+import dataclasses
 from dataclasses import dataclass
 from typing import Optional

@@ -66,6 +66,8 @@ class ChromaRadiance(Chroma):
         self.hidden_dim = params.hidden_dim
         self.n_layers = params.n_layers
         self.pe_embedder = EmbedND(dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim)
+        # replaces the operation
+        self.img_in = self._img_in
         self.img_in_patch = operations.Conv2d(
             params.in_channels,
             params.hidden_size,
@@ -164,7 +166,7 @@ class ChromaRadiance(Chroma):
             # Impossible to get here as we raise an error on unexpected types on initialization.
             raise NotImplementedError

-    def img_in(self, img: Tensor) -> Tensor:
+    def _img_in(self, img: Tensor) -> Tensor:
         img = self.img_in_patch(img)  # -> [B, Hidden, H/P, W/P]
         # flatten into a sequence for the transformer.
         return img.flatten(2).transpose(1, 2)  # -> [B, NumPatches, Hidden]
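The rename to `_img_in` pairs with the `self.img_in = self._img_in` assignment added earlier: if the method kept the name `img_in`, the instance attribute would shadow it on every instance, which is exactly the kind of collision linters flag. A distilled illustration with hypothetical classes, not the ComfyUI code:

```python
class Shadowed:
    def __init__(self):
        self.img_in = lambda img: img  # instance attribute...

    def img_in(self, img):  # ...shadows this method on every instance
        return img


class Fixed:
    def __init__(self):
        self.img_in = self._img_in  # attribute holds a bound method

    def _img_in(self, img):  # distinct name, no collision
        return img


print(Fixed().img_in("x"))  # "x", dispatched through the attribute
```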
@@ -263,7 +265,7 @@ class ChromaRadiance(Chroma):
         params = self.params
         if not overrides:
             return params
-        params_dict = {k: getattr(params, k) for k in params.__dataclass_fields__}
+        params_dict = dataclasses.asdict(params)
         nullable_keys = frozenset(("nerf_embedder_dtype",))
         bad_keys = tuple(k for k in overrides if k not in params_dict)
         if bad_keys:
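On the `dataclasses.asdict` swap just above: for a flat params dataclass the two forms agree; the one behavioral difference worth knowing is that `asdict` recurses into (and copies) nested dataclasses, while the `getattr` comprehension keeps references. A quick standalone check under that assumption:

```python
import dataclasses


@dataclasses.dataclass
class Params:
    # Hypothetical fields standing in for the real ChromaRadiance params.
    hidden_dim: int = 1280
    n_layers: int = 32


p = Params()
by_hand = {k: getattr(p, k) for k in p.__dataclass_fields__}  # old form
via_asdict = dataclasses.asdict(p)                            # new form
assert by_hand == via_asdict  # identical for a flat dataclass
```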
@@ -1,8 +1,7 @@
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from ..modules.diffusionmodules.model import ResnetBlock, AttnBlock
+from ...ops import disable_weight_init as ops


 class PixelShuffle2D(nn.Module):
@@ -52,7 +51,7 @@ class Encoder(nn.Module):
                                      out_channels=tgt,
                                      temb_channels=0,
                                      conv_op=ops.Conv2d)
                          for j in range(num_res_blocks)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
@@ -112,7 +111,7 @@ class Decoder(nn.Module):
                                      out_channels=tgt,
                                      temb_channels=0,
                                      conv_op=ops.Conv2d)
                          for j in range(num_res_blocks + 1)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
@@ -1,21 +1,22 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
-import comfy.ops
-import comfy.ldm.models.autoencoder
-ops = comfy.ops.disable_weight_init
+from ..modules.diffusionmodules.model import ResnetBlock, AttnBlock, VideoConv3d
+from ..models.autoencoder import DiagonalGaussianRegularizer
+from ...ops import disable_weight_init as ops


 class RMS_norm(nn.Module):
     def __init__(self, dim):
         super().__init__()
         shape = (dim, 1, 1, 1)
-        self.scale = dim**0.5
+        self.scale = dim ** 0.5
         self.gamma = nn.Parameter(torch.empty(shape))

     def forward(self, x):
         return F.normalize(x, dim=1) * self.scale * self.gamma


 class DnSmpl(nn.Module):
     def __init__(self, ic, oc, tds=True):
         super().__init__()
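Aside on the `RMS_norm` class touched above: `F.normalize` over the channel axis plus a `sqrt(dim)` scale is an RMS normalization, i.e. before the learned `gamma` each position has unit RMS across channels. A quick numeric check (shapes are illustrative, not taken from the model):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 8, 4, 4, 4)     # [B, C, T, H, W]
scale = 8 ** 0.5                   # sqrt(C), matching self.scale above
y = F.normalize(x, dim=1) * scale  # unit L2 norm over C, rescaled by sqrt(C)

# Per position, y now has RMS 1 over the channel dimension.
rms = y.pow(2).mean(dim=1).sqrt()
print(torch.allclose(rms, torch.ones_like(rms), atol=1e-5))  # True
```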
@@ -146,6 +147,7 @@ class UpSmpl(nn.Module):

         return h + sc

+
 class Encoder(nn.Module):
     def __init__(self, in_channels, z_channels, block_out_channels, num_res_blocks,
                  ffactor_spatial, ffactor_temporal, downsample_match_channel=True, **_):
@@ -166,7 +168,7 @@ class Encoder(nn.Module):
                                      out_channels=tgt,
                                      temb_channels=0,
                                      conv_op=VideoConv3d, norm_op=RMS_norm)
                          for j in range(num_res_blocks)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and downsample_match_channel else ch
@@ -182,7 +184,7 @@ class Encoder(nn.Module):
         self.norm_out = RMS_norm(ch)
         self.conv_out = VideoConv3d(ch, z_channels << 1, 3, 1, 1)

-        self.regul = comfy.ldm.models.autoencoder.DiagonalGaussianRegularizer()
+        self.regul = DiagonalGaussianRegularizer()

     def forward(self, x):
         x = self.conv_in(x)
@@ -209,6 +211,7 @@ class Encoder(nn.Module):
         out = out.permute(0, 2, 1, 3, 4).contiguous()
         return out

+
 class Decoder(nn.Module):
     def __init__(self, z_channels, out_channels, block_out_channels, num_res_blocks,
                  ffactor_spatial, ffactor_temporal, upsample_match_channel=True, **_):
@@ -236,7 +239,7 @@ class Decoder(nn.Module):
                                      out_channels=tgt,
                                      temb_channels=0,
                                      conv_op=VideoConv3d, norm_op=RMS_norm)
                          for j in range(num_res_blocks + 1)])
             ch = tgt
             if i < depth:
                 nxt = block_out_channels[i + 1] if i + 1 < len(block_out_channels) and upsample_match_channel else ch
@@ -5,8 +5,9 @@ from einops import rearrange
 import torch.nn.functional as F
 import math
 from .model import WanModel, sinusoidal_embedding_1d
-from comfy.ldm.modules.attention import optimized_attention
-import comfy.model_management
+from ..modules.attention import optimized_attention
+from ...model_management import cast_to

+
 class CausalConv1d(nn.Module):

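The `comfy.model_management.cast_to` to `cast_to` change repeated through this file only shortens the qualified name; at these call sites the helper moves a stored parameter to the input's device/dtype before use. A rough sketch of that behavior, assuming `cast_to` is essentially `Tensor.to` (the real ComfyUI helper has more options, e.g. non-blocking transfers):

```python
import torch


def cast_to(weight: torch.Tensor, dtype=None, device=None) -> torch.Tensor:
    # Sketch only: return `weight` on the requested device/dtype.
    return weight.to(device=device, dtype=dtype)


bias = torch.zeros(8, dtype=torch.float32)             # stored parameter
x = torch.randn(8, dtype=torch.float16)                # runtime input
y = x + cast_to(bias, dtype=x.dtype, device=x.device)  # safe mixed-dtype add
```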
@@ -46,7 +47,6 @@ class FaceEncoder(nn.Module):
         self.padding_tokens = nn.Parameter(torch.empty(1, 1, 1, hidden_dim, **factory_kwargs))

     def forward(self, x):
-
         x = rearrange(x, "b t c -> b c t")
         b, c, t = x.shape

@@ -67,7 +67,7 @@ class FaceEncoder(nn.Module):
         x = self.act(x)
         x = self.out_proj(x)
         x = rearrange(x, "(b n) t c -> b t n c", b=b)
-        padding = comfy.model_management.cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
+        padding = cast_to(self.padding_tokens, dtype=x.dtype, device=x.device).repeat(b, x.shape[1], 1, 1)
         x = torch.cat([x, padding], dim=-2)
         x_local = x.clone()

@@ -94,15 +94,14 @@ def get_norm_layer(norm_layer, operations=None):

 class FaceAdapter(nn.Module):
     def __init__(
         self,
         hidden_dim: int,
         heads_num: int,
         qk_norm: bool = True,
         qk_norm_type: str = "rms",
         num_adapter_layers: int = 1,
         dtype=None, device=None, operations=None
     ):
-
         factory_kwargs = {"dtype": dtype, "device": device}
         super().__init__()
         self.hidden_size = hidden_dim
@@ -122,29 +121,27 @@ class FaceAdapter(nn.Module):
         )

     def forward(
         self,
         x: torch.Tensor,
         motion_embed: torch.Tensor,
         idx: int,
         freqs_cis_q: Tuple[torch.Tensor, torch.Tensor] = None,
         freqs_cis_k: Tuple[torch.Tensor, torch.Tensor] = None,
     ) -> torch.Tensor:
-
         return self.fuser_blocks[idx](x, motion_embed, freqs_cis_q, freqs_cis_k)

-

 class FaceBlock(nn.Module):
     def __init__(
         self,
         hidden_size: int,
         heads_num: int,
         qk_norm: bool = True,
         qk_norm_type: str = "rms",
         qk_scale: float = None,
         dtype: Optional[torch.dtype] = None,
         device: Optional[torch.device] = None,
         operations=None
     ):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
@@ -153,7 +150,7 @@ class FaceBlock(nn.Module):
         self.hidden_size = hidden_size
         self.heads_num = heads_num
         head_dim = hidden_size // heads_num
-        self.scale = qk_scale or head_dim**-0.5
+        self.scale = qk_scale or head_dim ** -0.5

         self.linear1_kv = operations.Linear(hidden_size, hidden_size * 2, **factory_kwargs)
         self.linear1_q = operations.Linear(hidden_size, hidden_size, **factory_kwargs)
@@ -173,13 +170,12 @@ class FaceBlock(nn.Module):
         self.pre_norm_motion = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs)

     def forward(
         self,
         x: torch.Tensor,
         motion_vec: torch.Tensor,
         motion_mask: Optional[torch.Tensor] = None,
         # use_context_parallel=False,
     ) -> torch.Tensor:
-
         B, T, N, C = motion_vec.shape
         T_comp = T

@@ -212,6 +208,7 @@ class FaceBlock(nn.Module):

         return output

+
 # https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/upfirdn2d/upfirdn2d.py#L162
 def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1):
     _, minor, in_h, in_w = input.shape
@@ -230,9 +227,11 @@ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
     out = out.reshape(-1, minor, in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1)
     return out[:, :, ::down_y, ::down_x]

+
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
     return upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])

+
 # https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/ops/fused_act/fused_act.py#L81
 class FusedLeakyReLU(torch.nn.Module):
     def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5, dtype=None, device=None):
@@ -242,11 +241,13 @@ class FusedLeakyReLU(torch.nn.Module):
         self.scale = scale

     def forward(self, input):
-        return fused_leaky_relu(input, comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype), self.negative_slope, self.scale)
+        return fused_leaky_relu(input, cast_to(self.bias, device=input.device, dtype=input.dtype), self.negative_slope, self.scale)

+
 def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
     return F.leaky_relu(input + bias, negative_slope) * scale

+
 class Blur(torch.nn.Module):
     def __init__(self, kernel, pad, dtype=None, device=None):
         super().__init__()
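Unrelated to the lint fix itself, the `scale=2 ** 0.5` default in `fused_leaky_relu` above is the StyleGAN2 trick of rescaling after the nonlinearity: a leaky ReLU roughly halves the second moment of a zero-mean input, so multiplying by sqrt(2) keeps activation magnitudes stable across layers. A quick numerical check:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1_000_000)
y = F.leaky_relu(x, 0.2)
print(x.pow(2).mean().item())               # ~1.0
print(y.pow(2).mean().item())               # ~0.52 = (1 + 0.2**2) / 2
print((y * 2 ** 0.5).pow(2).mean().item())  # ~1.04, back near 1.0
```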
@@ -257,9 +258,10 @@ class Blur(torch.nn.Module):
         self.pad = pad

     def forward(self, input):
-        return upfirdn2d(input, comfy.model_management.cast_to(self.kernel, dtype=input.dtype, device=input.device), pad=self.pad)
+        return upfirdn2d(input, cast_to(self.kernel, dtype=input.dtype, device=input.device), pad=self.pad)

-#https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L590
+
+# https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L590
 class ScaledLeakyReLU(torch.nn.Module):
     def __init__(self, negative_slope=0.2):
         super().__init__()
@@ -268,6 +270,7 @@ class ScaledLeakyReLU(torch.nn.Module):
     def forward(self, input):
         return F.leaky_relu(input, negative_slope=self.negative_slope)

+
 # https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L605
 class EqualConv2d(torch.nn.Module):
     def __init__(self, in_channel, out_channel, kernel_size, stride=1, padding=0, bias=True, dtype=None, device=None, operations=None):
@@ -282,9 +285,10 @@ class EqualConv2d(torch.nn.Module):
         if self.bias is None:
             bias = None
         else:
-            bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype)
+            bias = cast_to(self.bias, device=input.device, dtype=input.dtype)

+        return F.conv2d(input, cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias, stride=self.stride, padding=self.padding)

-        return F.conv2d(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias, stride=self.stride, padding=self.padding)

+
 # https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L134
 class EqualLinear(torch.nn.Module):
@@ -300,12 +304,13 @@ class EqualLinear(torch.nn.Module):
         if self.bias is None:
             bias = None
         else:
-            bias = comfy.model_management.cast_to(self.bias, device=input.device, dtype=input.dtype) * self.lr_mul
+            bias = cast_to(self.bias, device=input.device, dtype=input.dtype) * self.lr_mul

         if self.activation:
-            out = F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale)
+            out = F.linear(input, cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale)
             return fused_leaky_relu(out, bias)
-        return F.linear(input, comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias)
+        return F.linear(input, cast_to(self.weight, device=input.device, dtype=input.dtype) * self.scale, bias=bias)

+
 # https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L654
 class ConvLayer(torch.nn.Sequential):
@@ -327,6 +332,7 @@ class ConvLayer(torch.nn.Sequential):

         super().__init__(*layers)

+
 # https://github.com/XPixelGroup/BasicSR/blob/8d56e3a045f9fb3e1d8872f92ee4a4f07f886b0a/basicsr/archs/stylegan2_arch.py#L704
 class ResBlock(torch.nn.Module):
     def __init__(self, in_channel, out_channel, dtype=None, device=None, operations=None):
@@ -360,6 +366,7 @@ class EncoderApp(torch.nn.Module):
             h = conv(h)
         return h.squeeze(-1).squeeze(-1)

+
 class Encoder(torch.nn.Module):
     def __init__(self, dim=512, motion_dim=20, dtype=None, device=None, operations=None):
         super().__init__()
@@ -369,6 +376,7 @@ class Encoder(torch.nn.Module):
     def encode_motion(self, x):
         return self.fc(self.net_app(x))

+
 class Direction(torch.nn.Module):
     def __init__(self, motion_dim, dtype=None, device=None, operations=None):
         super().__init__()
@@ -376,17 +384,19 @@ class Direction(torch.nn.Module):
         self.motion_dim = motion_dim

     def forward(self, input):
-        stabilized_weight = comfy.model_management.cast_to(self.weight, device=input.device, dtype=input.dtype) + 1e-8 * torch.eye(512, self.motion_dim, device=input.device, dtype=input.dtype)
+        stabilized_weight = cast_to(self.weight, device=input.device, dtype=input.dtype) + 1e-8 * torch.eye(512, self.motion_dim, device=input.device, dtype=input.dtype)
         Q, _ = torch.linalg.qr(stabilized_weight.float())
         if input is None:
             return Q
         return torch.sum(input.unsqueeze(-1) * Q.T.to(input.dtype), dim=1)

+
 class Synthesis(torch.nn.Module):
     def __init__(self, motion_dim, dtype=None, device=None, operations=None):
         super().__init__()
         self.direction = Direction(motion_dim, dtype=dtype, device=device, operations=operations)

+
 class Generator(torch.nn.Module):
     def __init__(self, style_dim=512, motion_dim=20, dtype=None, device=None, operations=None):
         super().__init__()
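A small check of the pattern in `Direction.forward` above: reduced QR factorization of a 512 x motion_dim weight yields a `Q` with orthonormal columns, so the learned motion directions always form an orthonormal basis; the tiny `eye()` term only guards against a rank-deficient weight. Standalone sketch with the dimensions used here:

```python
import torch

W = torch.randn(512, 20)        # stand-in for the Direction weight
Q, _ = torch.linalg.qr(W)       # reduced QR: Q is 512 x 20
print(Q.shape)                  # torch.Size([512, 20])
print(torch.allclose(Q.T @ Q, torch.eye(20), atol=1e-5))  # True
```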
@@ -397,6 +407,7 @@ class Generator(torch.nn.Module):
         motion_feat = self.enc.encode_motion(img)
         return self.dec.direction(motion_feat)

+
 class AnimateWanModel(WanModel):
     r"""
     Wan diffusion backbone supporting both text-to-video and image-to-video.
@@ -481,16 +492,16 @@ class AnimateWanModel(WanModel):
         return x, motion_vec

     def forward_orig(
         self,
         x,
         t,
         context,
         clip_fea=None,
         pose_latents=None,
         face_pixel_values=None,
         freqs=None,
         transformer_options={},
         **kwargs,
     ):
         # embeddings
         x = self.patch_embedding(x.float()).to(x.dtype)
@@ -529,6 +540,7 @@ class AnimateWanModel(WanModel):
                     out = {}
                     out["img"] = block(args["img"], context=args["txt"], e=args["vec"], freqs=args["pe"], context_img_len=context_img_len, transformer_options=args["transformer_options"])
                     return out
+
                 out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "vec": e0, "pe": freqs, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 x = out["img"]
             else:
@@ -811,12 +811,12 @@ class VAELoader:

     # TODO: scale factor?
     def load_vae(self, vae_name):
+        metadata = {}
         if vae_name == "pixel_space":
-            sd = {}
-            sd["pixel_space_vae"] = torch.tensor(1.0)
+            sd_ = {}
+            sd_["pixel_space_vae"] = torch.tensor(1.0)
         elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
             sd_ = self.load_taesd(vae_name)
-            metadata = {}
         else:
             vae_path = get_full_path_or_raise("vae", vae_name, KNOWN_VAES)
             sd_, metadata = utils.load_torch_file(vae_path, return_metadata=True)
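The `load_vae` hunk reads as two lint fixes at once: the first branch built `sd` while the rest of the method uses `sd_` (so `sd_` could be unbound for `pixel_space`), and `metadata` was only assigned on two of the three paths. Distilled into a hypothetical standalone function, not the ComfyUI code:

```python
def load_before(name):
    if name == "pixel":
        sd = {"pixel_space_vae": 1.0}  # bug: everything below reads sd_
    elif name == "taesd":
        sd_ = {"taesd": 1.0}
        metadata = {}
    else:
        sd_, metadata = {"other": 1.0}, {"k": "v"}
    return sd_, metadata  # possibly unbound on the first branch


def load_after(name):
    metadata = {}  # bound on every path up front
    if name == "pixel":
        sd_ = {"pixel_space_vae": 1.0}  # renamed to match the rest
    elif name == "taesd":
        sd_ = {"taesd": 1.0}
    else:
        sd_, metadata = {"other": 1.0}, {"k": "v"}
    return sd_, metadata
```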
@@ -1,10 +1,11 @@
-from .. import sd1_clip
-from .llama import Qwen25_7BVLI
-from .qwen_image import QwenImageTokenizer, QwenImageTEModel
-from transformers import ByT5Tokenizer
-import os
 import re

+from transformers import ByT5Tokenizer
+
+from .llama import Qwen25_7BVLI
+from .qwen_image import QwenImageTokenizer, QwenImageTEModel
+from .t5 import T5
+from .. import sd1_clip
 from ..component_model import files


@@ -64,7 +65,7 @@ class ByT5SmallModel(sd1_clip.SDClipModel):
             model_options = {}
         textmodel_json_config = files.get_path_as_dict(textmodel_json_config, "byt5_config_small_glyph.json", package=__package__)

-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, model_options=model_options, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, zero_out_masked=True)


 class HunyuanImageTEModel(QwenImageTEModel):