From bddb02660c983669090ddbd66cdd0591e6a1f858 Mon Sep 17 00:00:00 2001 From: City <125218114+city96@users.noreply.github.com> Date: Fri, 20 Dec 2024 21:25:00 +0100 Subject: [PATCH 1/2] Add PixArt model support (#6055) * PixArt initial version * PixArt Diffusers convert logic * pos_emb and interpolation logic * Reduce duplicate code * Formatting * Use optimized attention * Edit empty token logic * Basic PixArt LoRA support * Fix aspect ratio logic * PixArtAlpha text encode with conds * Use same detection key logic for PixArt diffusers --- comfy/ldm/pixart/blocks.py | 382 +++++++++++++++++++++++++++++++ comfy/ldm/pixart/pixart.py | 201 ++++++++++++++++ comfy/ldm/pixart/pixartms.py | 246 ++++++++++++++++++++ comfy/lora.py | 15 +- comfy/model_base.py | 16 ++ comfy/model_detection.py | 34 +++ comfy/sd.py | 6 + comfy/sd1_clip.py | 8 +- comfy/supported_models.py | 34 ++- comfy/text_encoders/pixart_t5.py | 42 ++++ comfy/utils.py | 71 ++++++ comfy_extras/nodes_pixart.py | 24 ++ nodes.py | 5 +- 13 files changed, 1079 insertions(+), 5 deletions(-) create mode 100644 comfy/ldm/pixart/blocks.py create mode 100644 comfy/ldm/pixart/pixart.py create mode 100644 comfy/ldm/pixart/pixartms.py create mode 100644 comfy/text_encoders/pixart_t5.py create mode 100644 comfy_extras/nodes_pixart.py diff --git a/comfy/ldm/pixart/blocks.py b/comfy/ldm/pixart/blocks.py new file mode 100644 index 000000000..7ad2ec29e --- /dev/null +++ b/comfy/ldm/pixart/blocks.py @@ -0,0 +1,382 @@ +# Based on: +# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license] +# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license] +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from comfy import model_management +from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, Mlp, timestep_embedding +from comfy.ldm.modules.attention import optimized_attention + +if model_management.xformers_enabled(): + import xformers.ops + if int((xformers.__version__).split(".")[2]) >= 28: + block_diagonal_mask_from_seqlens = xformers.ops.fmha.attn_bias.BlockDiagonalMask.from_seqlens + else: + block_diagonal_mask_from_seqlens = xformers.ops.fmha.BlockDiagonalMask.from_seqlens + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + +def t2i_modulate(x, shift, scale): + return x * (1 + scale) + shift + +class MultiHeadCrossAttention(nn.Module): + def __init__(self, d_model, num_heads, attn_drop=0., proj_drop=0., dtype=None, device=None, operations=None, **kwargs): + super(MultiHeadCrossAttention, self).__init__() + assert d_model % num_heads == 0, "d_model must be divisible by num_heads" + + self.d_model = d_model + self.num_heads = num_heads + self.head_dim = d_model // num_heads + + self.q_linear = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.kv_linear = operations.Linear(d_model, d_model*2, dtype=dtype, device=device) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = operations.Linear(d_model, d_model, dtype=dtype, device=device) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, cond, mask=None): + # query/value: img tokens; key: condition; mask: if padding tokens + B, N, C = x.shape + + q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim) + kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim) + k, v = kv.unbind(2) + + # TODO: xformers needs separate mask logic here + if model_management.xformers_enabled(): + attn_bias = None + if mask is not None: + attn_bias 
= block_diagonal_mask_from_seqlens([N] * B, mask)
+            x = xformers.ops.memory_efficient_attention(q, k, v, p=0, attn_bias=attn_bias)
+        else:
+            q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
+            attn_mask = None
+            if mask is not None and len(mask) > 1:
+                # Create equivalent of xformer diagonal block mask, still only correct for square masks
+                # But depth doesn't matter as tensors can expand in that dimension
+                attn_mask_template = torch.ones(
+                    [q.shape[2] // B, mask[0]],
+                    dtype=torch.bool,
+                    device=q.device
+                )
+                attn_mask = torch.block_diag(attn_mask_template)
+
+                # create a mask on the diagonal for each mask in the batch
+                for _ in range(B - 1):
+                    attn_mask = torch.block_diag(attn_mask, attn_mask_template)
+
+            x = optimized_attention(q, k, v, self.num_heads, mask=attn_mask, skip_reshape=True)
+
+        x = x.view(B, -1, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class AttentionKVCompress(nn.Module):
+    """Multi-head Attention block with KV token compression and qk norm."""
+    def __init__(self, dim, num_heads=8, qkv_bias=True, sampling='conv', sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **kwargs):
+        """
+        Args:
+            dim (int): Number of input channels.
+            num_heads (int): Number of attention heads.
+            qkv_bias (bool): If True, add a learnable bias to query, key, value.
+        """
+        super().__init__()
+        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
+        self.num_heads = num_heads
+        self.head_dim = dim // num_heads
+        self.scale = self.head_dim ** -0.5
+
+        self.qkv = operations.Linear(dim, dim * 3, bias=qkv_bias, dtype=dtype, device=device)
+        self.proj = operations.Linear(dim, dim, dtype=dtype, device=device)
+
+        self.sampling = sampling  # ['conv', 'ave', 'uniform', 'uniform_every']
+        self.sr_ratio = sr_ratio
+        if sr_ratio > 1 and sampling == 'conv':
+            # Avg Conv Init.
+            self.sr = operations.Conv2d(dim, dim, groups=dim, kernel_size=sr_ratio, stride=sr_ratio, dtype=dtype, device=device)
+            # self.sr.weight.data.fill_(1/sr_ratio**2)
+            # self.sr.bias.data.zero_()
+            self.norm = operations.LayerNorm(dim, dtype=dtype, device=device)
+        if qk_norm:
+            self.q_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
+            self.k_norm = operations.LayerNorm(dim, dtype=dtype, device=device)
+        else:
+            self.q_norm = nn.Identity()
+            self.k_norm = nn.Identity()
+
+    def downsample_2d(self, tensor, H, W, scale_factor, sampling=None):
+        B, N, C = tensor.shape
+        # Callers unpack (tensor, token_count), so the early-out must return both
+        if sampling is None or scale_factor == 1:
+            return tensor, N
+
+        if sampling == 'uniform_every':
+            return tensor[:, ::scale_factor], int(N // scale_factor)
+
+        tensor = tensor.reshape(B, H, W, C).permute(0, 3, 1, 2)
+        new_H, new_W = int(H / scale_factor), int(W / scale_factor)
+        new_N = new_H * new_W
+
+        if sampling == 'ave':
+            tensor = F.interpolate(
+                tensor, scale_factor=1 / scale_factor, mode='nearest'
+            ).permute(0, 2, 3, 1)
+        elif sampling == 'uniform':
+            tensor = tensor[:, :, ::scale_factor, ::scale_factor].permute(0, 2, 3, 1)
+        elif sampling == 'conv':
+            tensor = self.sr(tensor).reshape(B, C, -1).permute(0, 2, 1)
+            tensor = self.norm(tensor)
+        else:
+            raise ValueError(f"Unknown KV compression sampling mode: {sampling}")
+
+        return tensor.reshape(B, new_N, C).contiguous(), new_N
+
+    def forward(self, x, mask=None, HW=None, block_id=None):
+        B, N, C = x.shape  # e.g. 2, 4096, 1152
+        new_N = N
+        if HW is None:
+            H = W = int(N ** 0.5)
+        else:
+            H, W = HW
+        qkv = self.qkv(x).reshape(B, N, 3, C)
+
+        q, k, v = qkv.unbind(2)
+        dtype = q.dtype
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+
+        # KV compression
+        if self.sr_ratio > 1:
+            k, new_N = self.downsample_2d(k, H, W, self.sr_ratio, sampling=self.sampling)
+            v, new_N = self.downsample_2d(v, H, W, self.sr_ratio, sampling=self.sampling)
+
+        q = q.reshape(B, N, self.num_heads, C // self.num_heads).to(dtype)
+        k = k.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
+        v = v.reshape(B, new_N, self.num_heads, C // self.num_heads).to(dtype)
+
+        if mask is not None:
+            raise NotImplementedError("Attn mask logic not added for self attention")
+
+        # This is never called at the moment
+        # attn_bias = None
+        # if mask is not None:
+        #     attn_bias = torch.zeros([B * self.num_heads, q.shape[1], k.shape[1]], dtype=q.dtype, device=q.device)
+        #     attn_bias.masked_fill_(mask.squeeze(1).repeat(self.num_heads, 1, 1) == 0, float('-inf'))
+
+        # attention 2
+        q, k, v = map(lambda t: t.transpose(1, 2), (q, k, v),)
+        x = optimized_attention(q, k, v, self.num_heads, mask=None, skip_reshape=True)
+
+        x = x.view(B, N, C)
+        x = self.proj(x)
+        return x
+
+
+class FinalLayer(nn.Module):
+    """
+    The final layer of PixArt.
+    """
+    def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device)
+        self.adaLN_modulation = nn.Sequential(
+            nn.SiLU(),
+            operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device)
+        )
+
+    def forward(self, x, c):
+        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+        x = modulate(self.norm_final(x), shift, scale)
+        x = self.linear(x)
+        return x
+
+class T2IFinalLayer(nn.Module):
+    """
+    The final layer of PixArt, using a learned scale/shift table instead of adaLN modulation.
+    
+ """ + def __init__(self, hidden_size, patch_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size ** 0.5) + self.out_channels = out_channels + + def forward(self, x, t): + dtype = x.dtype + shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1) + x = t2i_modulate(self.norm_final(x), shift, scale) + x = self.linear(x.to(dtype)) + return x + + +class MaskFinalLayer(nn.Module): + """ + The final layer of PixArt. + """ + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels, dtype=None, device=None, operations=None): + super().__init__() + self.norm_final = operations.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(c_emb_size, 2 * final_hidden_size, bias=True, dtype=dtype, device=device) + ) + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DecoderLayer(nn.Module): + """ + The final layer of PixArt. + """ + def __init__(self, hidden_size, decoder_hidden_size, dtype=None, device=None, operations=None): + super().__init__() + self.norm_decoder = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device) + self.linear = operations.Linear(hidden_size, decoder_hidden_size, bias=True, dtype=dtype, device=device) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, 2 * hidden_size, bias=True, dtype=dtype, device=device) + ) + def forward(self, x, t): + shift, scale = self.adaLN_modulation(t).chunk(2, dim=1) + x = modulate(self.norm_decoder(x), shift, scale) + x = self.linear(x) + return x + + +class SizeEmbedder(TimestepEmbedder): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256, dtype=None, device=None, operations=None): + super().__init__(hidden_size=hidden_size, frequency_embedding_size=frequency_embedding_size, operations=operations) + self.mlp = nn.Sequential( + operations.Linear(frequency_embedding_size, hidden_size, bias=True, dtype=dtype, device=device), + nn.SiLU(), + operations.Linear(hidden_size, hidden_size, bias=True, dtype=dtype, device=device), + ) + self.frequency_embedding_size = frequency_embedding_size + self.outdim = hidden_size + + def forward(self, s, bs): + if s.ndim == 1: + s = s[:, None] + assert s.ndim == 2 + if s.shape[0] != bs: + s = s.repeat(bs//s.shape[0], 1) + assert s.shape[0] == bs + b, dims = s.shape[0], s.shape[1] + s = rearrange(s, "b d -> (b d)") + s_freq = timestep_embedding(s, self.frequency_embedding_size) + s_emb = self.mlp(s_freq.to(s.dtype)) + s_emb = rearrange(s_emb, "(b d) d2 -> b (d d2)", b=b, d=dims, d2=self.outdim) + return s_emb + + +class LabelEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 
+ """ + def __init__(self, num_classes, hidden_size, dropout_prob, dtype=None, device=None, operations=None): + super().__init__() + use_cfg_embedding = dropout_prob > 0 + self.embedding_table = operations.Embedding(num_classes + use_cfg_embedding, hidden_size, dtype=dtype, device=device), + self.num_classes = num_classes + self.dropout_prob = dropout_prob + + def token_drop(self, labels, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(labels.shape[0]).cuda() < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = torch.where(drop_ids, self.num_classes, labels) + return labels + + def forward(self, labels, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + labels = self.token_drop(labels, force_drop_ids) + embeddings = self.embedding_table(labels) + return embeddings + + +class CaptionEmbedder(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None): + super().__init__() + self.y_proj = Mlp( + in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, + dtype=dtype, device=device, operations=operations, + ) + self.register_buffer("y_embedding", nn.Parameter(torch.randn(token_num, in_channels) / in_channels ** 0.5)) + self.uncond_prob = uncond_prob + + def token_drop(self, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = torch.rand(caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return caption + + def forward(self, caption, train, force_drop_ids=None): + if train: + assert caption.shape[2:] == self.y_embedding.shape + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + caption = self.token_drop(caption, force_drop_ids) + caption = self.y_proj(caption) + return caption + + +class CaptionEmbedderDoubleBr(nn.Module): + """ + Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. + """ + def __init__(self, in_channels, hidden_size, uncond_prob, act_layer=nn.GELU(approximate='tanh'), token_num=120, dtype=None, device=None, operations=None): + super().__init__() + self.proj = Mlp( + in_features=in_channels, hidden_features=hidden_size, out_features=hidden_size, act_layer=act_layer, + dtype=dtype, device=device, operations=operations, + ) + self.embedding = nn.Parameter(torch.randn(1, in_channels) / 10 ** 0.5) + self.y_embedding = nn.Parameter(torch.randn(token_num, in_channels) / 10 ** 0.5) + self.uncond_prob = uncond_prob + + def token_drop(self, global_caption, caption, force_drop_ids=None): + """ + Drops labels to enable classifier-free guidance. 
+ """ + if force_drop_ids is None: + drop_ids = torch.rand(global_caption.shape[0]).cuda() < self.uncond_prob + else: + drop_ids = force_drop_ids == 1 + global_caption = torch.where(drop_ids[:, None], self.embedding, global_caption) + caption = torch.where(drop_ids[:, None, None, None], self.y_embedding, caption) + return global_caption, caption + + def forward(self, caption, train, force_drop_ids=None): + assert caption.shape[2: ] == self.y_embedding.shape + global_caption = caption.mean(dim=2).squeeze() + use_dropout = self.uncond_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + global_caption, caption = self.token_drop(global_caption, caption, force_drop_ids) + y_embed = self.proj(global_caption) + return y_embed, caption diff --git a/comfy/ldm/pixart/pixart.py b/comfy/ldm/pixart/pixart.py new file mode 100644 index 000000000..cd572efce --- /dev/null +++ b/comfy/ldm/pixart/pixart.py @@ -0,0 +1,201 @@ +# Based on: +# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license] +# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license] +import torch +import torch.nn as nn + +from .blocks import ( + t2i_modulate, + CaptionEmbedder, + AttentionKVCompress, + MultiHeadCrossAttention, + T2IFinalLayer, +) +from comfy.ldm.modules.diffusionmodules.mmdit import PatchEmbed, TimestepEmbedder, Mlp, get_1d_sincos_pos_embed_from_grid_torch + + +class PixArtBlock(nn.Module): + """ + A PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0, input_size=None, sampling=None, sr_ratio=1, qk_norm=False, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = AttentionKVCompress( + hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio, + qk_norm=qk_norm, **block_kwargs + ) + self.cross_attn = MultiHeadCrossAttention(hidden_size, num_heads, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + # to be compatible with lower version pytorch + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu, drop=0) + self.drop_path = nn.Identity() #DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5) + self.sampling = sampling + self.sr_ratio = sr_ratio + + def forward(self, x, y, t, mask=None, **kwargs): + B, N, C = x.shape + + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None] + t.reshape(B, 6, -1)).chunk(6, dim=1) + x = x + self.drop_path(gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa)).reshape(B, N, C)) + x = x + self.cross_attn(x, y, mask) + x = x + self.drop_path(gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp))) + + return x + + +### Core PixArt Model ### +class PixArt(nn.Module): + """ + Diffusion model with a Transformer backbone. 
+ """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + pred_sigma=True, + drop_path: float = 0., + caption_channels=4096, + pe_interpolation=1.0, + pe_precision=None, + config=None, + model_max_length=120, + qk_norm=False, + kv_compress_config=None, + **kwargs, + ): + super().__init__() + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.pe_interpolation = pe_interpolation + self.pe_precision = pe_precision + self.depth = depth + + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + num_patches = self.x_embedder.num_patches + self.base_size = input_size // self.patch_size + # Will use fixed sin-cos embedding: + self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) + + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential( + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, + act_layer=approx_gelu, token_num=model_max_length + ) + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + self.kv_compress_config = kv_compress_config + if kv_compress_config is None: + self.kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], + } + self.blocks = nn.ModuleList([ + PixArtBlock( + hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], + input_size=(input_size // patch_size, input_size // patch_size), + sampling=self.kv_compress_config['sampling'], + sr_ratio=int( + self.kv_compress_config['scale_factor'] + ) if i in self.kv_compress_config['kv_compress_layer'] else 1, + qk_norm=qk_norm, + ) + for i in range(depth) + ]) + self.final_layer = T2IFinalLayer(hidden_size, patch_size, self.out_channels) + + def forward_raw(self, x, t, y, mask=None, data_info=None): + """ + Original forward pass of PixArt. 
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+        t: (N,) tensor of diffusion timesteps
+        y: (N, 1, 120, C) tensor of caption embeddings
+        """
+        x = x.to(self.dtype)
+        timestep = t.to(self.dtype)
+        y = y.to(self.dtype)
+        pos_embed = self.pos_embed.to(self.dtype)
+        x = self.x_embedder(x) + pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
+        t = self.t_embedder(timestep.to(x.dtype))  # (N, D)
+        t0 = self.t_block(t)
+        y = self.y_embedder(y, self.training)  # (N, 1, L, D)
+        if mask is not None:
+            if mask.shape[0] != y.shape[0]:
+                mask = mask.repeat(y.shape[0] // mask.shape[0], 1)
+            mask = mask.squeeze(1).squeeze(1)
+            y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1])
+            y_lens = mask.sum(dim=1).tolist()
+        else:
+            y_lens = [y.shape[2]] * y.shape[0]
+            y = y.squeeze(1).view(1, -1, x.shape[-1])
+        for block in self.blocks:
+            x = block(x, y, t0, y_lens)  # (N, T, D)
+        x = self.final_layer(x, t)  # (N, T, patch_size ** 2 * out_channels)
+        x = self.unpatchify(x)  # (N, out_channels, H, W)
+        return x
+
+    def forward(self, x, timesteps, context, y=None, **kwargs):
+        """
+        Forward pass that adapts comfy input to original forward function
+        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
+        timesteps: (N,) tensor of diffusion timesteps
+        context: (N, 1, 120, C) conditioning
+        y: extra conditioning.
+        """
+        ## Still accepts the input w/o that dim but returns garbage
+        if len(context.shape) == 3:
+            context = context.unsqueeze(1)
+
+        ## run original forward pass
+        out = self.forward_raw(
+            x = x.to(self.dtype),
+            t = timesteps.to(self.dtype),
+            y = context.to(self.dtype),
+        )
+
+        ## only return EPS
+        out = out.to(torch.float)
+        eps, _ = out[:, :self.in_channels], out[:, self.in_channels:]
+        return eps
+
+    def unpatchify(self, x):
+        """
+        x: (N, T, patch_size**2 * C)
+        imgs: (N, C, H, W)
+        """
+        c = self.out_channels
+        p = self.x_embedder.patch_size[0]
+        h = w = int(x.shape[1] ** 0.5)
+        assert h * w == x.shape[1]
+
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
+        x = torch.einsum('nhwpqc->nchpwq', x)
+        imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p))
+        return imgs
+
+def get_2d_sincos_pos_embed_torch(embed_dim, w, h, pe_interpolation=1.0, base_size=16, device=None, dtype=torch.float32):
+    grid_h, grid_w = torch.meshgrid(
+        torch.arange(h, device=device, dtype=dtype) / (h/base_size) / pe_interpolation,
+        torch.arange(w, device=device, dtype=dtype) / (w/base_size) / pe_interpolation,
+        indexing='ij'
+    )
+    emb_h = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_h, device=device, dtype=dtype)
+    emb_w = get_1d_sincos_pos_embed_from_grid_torch(embed_dim // 2, grid_w, device=device, dtype=dtype)
+    emb = torch.cat([emb_w, emb_h], dim=1)  # (H*W, D)
+    return emb
diff --git a/comfy/ldm/pixart/pixartms.py b/comfy/ldm/pixart/pixartms.py
new file mode 100644
index 000000000..195063b0a
--- /dev/null
+++ b/comfy/ldm/pixart/pixartms.py
@@ -0,0 +1,246 @@
+# Based on:
+# https://github.com/PixArt-alpha/PixArt-alpha [Apache 2.0 license]
+# https://github.com/PixArt-alpha/PixArt-sigma [Apache 2.0 license]
+import torch
+import torch.nn as nn
+
+from .blocks import (
+    t2i_modulate,
+    CaptionEmbedder,
+    AttentionKVCompress,
+    MultiHeadCrossAttention,
+    T2IFinalLayer,
+    SizeEmbedder,
+)
+from comfy.ldm.modules.diffusionmodules.mmdit import TimestepEmbedder, PatchEmbed, Mlp
+from .pixart import PixArt, get_2d_sincos_pos_embed_torch
+
+
+class PixArtMSBlock(nn.Module):
+    """
+    A PixArt block with adaptive layer norm (adaLN-single) conditioning.
+    """
+    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, drop_path=0., input_size=None,
+                 sampling=None, sr_ratio=1, qk_norm=False, dtype=None, device=None, operations=None, **block_kwargs):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        self.attn = AttentionKVCompress(
+            hidden_size, num_heads=num_heads, qkv_bias=True, sampling=sampling, sr_ratio=sr_ratio,
+            qk_norm=qk_norm, dtype=dtype, device=device, operations=operations, **block_kwargs
+        )
+        self.cross_attn = MultiHeadCrossAttention(
+            hidden_size, num_heads, dtype=dtype, device=device, operations=operations, **block_kwargs
+        )
+        self.norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
+        # to be compatible with lower version pytorch
+        approx_gelu = lambda: nn.GELU(approximate="tanh")
+        self.mlp = Mlp(
+            in_features=hidden_size, hidden_features=int(hidden_size * mlp_ratio), act_layer=approx_gelu,
+            dtype=dtype, device=device, operations=operations
+        )
+        self.scale_shift_table = nn.Parameter(torch.randn(6, hidden_size) / hidden_size ** 0.5)
+
+    def forward(self, x, y, t, mask=None, HW=None, **kwargs):
+        B, N, C = x.shape
+
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None].to(x.dtype) + t.reshape(B, 6, -1)).chunk(6, dim=1)
+        x = x + (gate_msa * self.attn(t2i_modulate(self.norm1(x), shift_msa, scale_msa), HW=HW))
+        x = x + self.cross_attn(x, y, mask)
+        x = x + (gate_mlp * self.mlp(t2i_modulate(self.norm2(x), shift_mlp, scale_mlp)))
+
+        return x
+
+
+### Core PixArt Model ###
+class PixArtMS(PixArt):
+    """
+    Diffusion model with a Transformer backbone.
+    
+ """ + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + class_dropout_prob=0.1, + learn_sigma=True, + pred_sigma=True, + drop_path: float = 0., + caption_channels=4096, + pe_interpolation=None, + pe_precision=None, + config=None, + model_max_length=120, + micro_condition=True, + qk_norm=False, + kv_compress_config=None, + dtype=None, + device=None, + operations=None, + **kwargs, + ): + nn.Module.__init__(self) + self.dtype = dtype + self.pred_sigma = pred_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if pred_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.pe_interpolation = pe_interpolation + self.pe_precision = pe_precision + self.hidden_size = hidden_size + self.depth = depth + + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.t_block = nn.Sequential( + nn.SiLU(), + operations.Linear(hidden_size, 6 * hidden_size, bias=True, dtype=dtype, device=device) + ) + self.x_embedder = PatchEmbed( + patch_size=patch_size, + in_chans=in_channels, + embed_dim=hidden_size, + bias=True, + dtype=dtype, + device=device, + operations=operations + ) + self.t_embedder = TimestepEmbedder( + hidden_size, dtype=dtype, device=device, operations=operations, + ) + self.y_embedder = CaptionEmbedder( + in_channels=caption_channels, hidden_size=hidden_size, uncond_prob=class_dropout_prob, + act_layer=approx_gelu, token_num=model_max_length, + dtype=dtype, device=device, operations=operations, + ) + + self.micro_conditioning = micro_condition + if self.micro_conditioning: + self.csize_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations) + self.ar_embedder = SizeEmbedder(hidden_size//3, dtype=dtype, device=device, operations=operations) + + # For fixed sin-cos embedding: + # num_patches = (input_size // patch_size) * (input_size // patch_size) + # self.base_size = input_size // self.patch_size + # self.register_buffer("pos_embed", torch.zeros(1, num_patches, hidden_size)) + + drop_path = [x.item() for x in torch.linspace(0, drop_path, depth)] # stochastic depth decay rule + if kv_compress_config is None: + kv_compress_config = { + 'sampling': None, + 'scale_factor': 1, + 'kv_compress_layer': [], + } + self.blocks = nn.ModuleList([ + PixArtMSBlock( + hidden_size, num_heads, mlp_ratio=mlp_ratio, drop_path=drop_path[i], + sampling=kv_compress_config['sampling'], + sr_ratio=int(kv_compress_config['scale_factor']) if i in kv_compress_config['kv_compress_layer'] else 1, + qk_norm=qk_norm, + dtype=dtype, + device=device, + operations=operations, + ) + for i in range(depth) + ]) + self.final_layer = T2IFinalLayer( + hidden_size, patch_size, self.out_channels, dtype=dtype, device=device, operations=operations + ) + + def forward_orig(self, x, timestep, y, mask=None, c_size=None, c_ar=None, **kwargs): + """ + Original forward pass of PixArt. 
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N, 1, 120, C) conditioning + ar: (N, 1): aspect ratio + cs: (N ,2) size conditioning for height/width + """ + B, C, H, W = x.shape + c_res = (H + W) // 2 + pe_interpolation = self.pe_interpolation + if pe_interpolation is None or self.pe_precision is not None: + # calculate pe_interpolation on-the-fly + pe_interpolation = round(c_res / (512/8.0), self.pe_precision or 0) + + pos_embed = get_2d_sincos_pos_embed_torch( + self.hidden_size, + h=(H // self.patch_size), + w=(W // self.patch_size), + pe_interpolation=pe_interpolation, + base_size=((round(c_res / 64) * 64) // self.patch_size), + device=x.device, + dtype=x.dtype, + ).unsqueeze(0) + + x = self.x_embedder(x) + pos_embed # (N, T, D), where T = H * W / patch_size ** 2 + t = self.t_embedder(timestep, x.dtype) # (N, D) + + if self.micro_conditioning and (c_size is not None and c_ar is not None): + bs = x.shape[0] + c_size = self.csize_embedder(c_size, bs) # (N, D) + c_ar = self.ar_embedder(c_ar, bs) # (N, D) + t = t + torch.cat([c_size, c_ar], dim=1) + + t0 = self.t_block(t) + y = self.y_embedder(y, self.training) # (N, D) + + if mask is not None: + if mask.shape[0] != y.shape[0]: + mask = mask.repeat(y.shape[0] // mask.shape[0], 1) + mask = mask.squeeze(1).squeeze(1) + y = y.squeeze(1).masked_select(mask.unsqueeze(-1) != 0).view(1, -1, x.shape[-1]) + y_lens = mask.sum(dim=1).tolist() + else: + y_lens = [y.shape[2]] * y.shape[0] + y = y.squeeze(1).view(1, -1, x.shape[-1]) + for block in self.blocks: + x = block(x, y, t0, y_lens, (H, W), **kwargs) # (N, T, D) + + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x, H, W) # (N, out_channels, H, W) + + return x + + def forward(self, x, timesteps, context, c_size=None, c_ar=None, **kwargs): + B, C, H, W = x.shape + + # Fallback for missing microconds + if self.micro_conditioning: + if c_size is None: + c_size = torch.tensor([H*8, W*8], dtype=x.dtype, device=x.device).repeat(B, 1) + + if c_ar is None: + c_ar = torch.tensor([H/W], dtype=x.dtype, device=x.device).repeat(B, 1) + + ## Still accepts the input w/o that dim but returns garbage + if len(context.shape) == 3: + context = context.unsqueeze(1) + + ## run original forward pass + out = self.forward_orig(x, timesteps, context, c_size=c_size, c_ar=c_ar) + + ## only return EPS + if self.pred_sigma: + return out[:, :self.in_channels] + return out + + def unpatchify(self, x, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] + h = h // self.patch_size + w = w // self.patch_size + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs diff --git a/comfy/lora.py b/comfy/lora.py index c43a6fe47..ec3da6f4c 100644 --- a/comfy/lora.py +++ b/comfy/lora.py @@ -344,7 +344,6 @@ def model_lora_keys_unet(model, key_map={}): key_lora = "lycoris_{}".format(k[:-len(".weight")].replace(".", "_")) #simpletuner lycoris format key_map[key_lora] = to - if isinstance(model, comfy.model_base.AuraFlow): #Diffusers lora AuraFlow diffusers_keys = comfy.utils.auraflow_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") for k in diffusers_keys: @@ -353,6 +352,20 @@ def model_lora_keys_unet(model, key_map={}): key_lora = "transformer.{}".format(k[:-len(".weight")]) 
#simpletrainer and probably regular diffusers lora format key_map[key_lora] = to + if isinstance(model, comfy.model_base.PixArt): + diffusers_keys = comfy.utils.pixart_to_diffusers(model.model_config.unet_config, output_prefix="diffusion_model.") + for k in diffusers_keys: + if k.endswith(".weight"): + to = diffusers_keys[k] + key_lora = "transformer.{}".format(k[:-len(".weight")]) #default format + key_map[key_lora] = to + + key_lora = "base_model.model.{}".format(k[:-len(".weight")]) #diffusers training script + key_map[key_lora] = to + + key_lora = "unet.base_model.model.{}".format(k[:-len(".weight")]) #old reference peft script + key_map[key_lora] = to + if isinstance(model, comfy.model_base.HunyuanDiT): for k in sdk: if k.startswith("diffusion_model.") and k.endswith(".weight"): diff --git a/comfy/model_base.py b/comfy/model_base.py index 18c6f2244..99c53e57d 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -26,6 +26,7 @@ from comfy.ldm.modules.diffusionmodules.upscaling import ImageConcatWithNoiseAug from comfy.ldm.modules.diffusionmodules.mmdit import OpenAISignatureMMDITWrapper import comfy.ldm.genmo.joint_model.asymm_models_joint import comfy.ldm.aura.mmdit +import comfy.ldm.pixart.pixartms import comfy.ldm.hydit.models import comfy.ldm.audio.dit import comfy.ldm.audio.embedders @@ -718,6 +719,21 @@ class HunyuanDiT(BaseModel): out['image_meta_size'] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width, target_height, target_width, 0, 0]])) return out +class PixArt(BaseModel): + def __init__(self, model_config, model_type=ModelType.EPS, device=None): + super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.pixart.pixartms.PixArtMS) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + + width = kwargs.get("width", None) + height = kwargs.get("height", None) + if width is not None and height is not None: + out["c_size"] = comfy.conds.CONDRegular(torch.FloatTensor([[height, width]])) + out["c_ar"] = comfy.conds.CONDRegular(torch.FloatTensor([[kwargs.get("aspect_ratio", height/width)]])) + + return out + class Flux(BaseModel): def __init__(self, model_config, model_type=ModelType.FLUX, device=None): super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux) diff --git a/comfy/model_detection.py b/comfy/model_detection.py index d6b59fd7d..c53bef5bb 100644 --- a/comfy/model_detection.py +++ b/comfy/model_detection.py @@ -203,11 +203,42 @@ def detect_unet_config(state_dict, key_prefix): dit_config["rope_theta"] = 10000.0 return dit_config + if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys and '{}pos_embed.proj.bias'.format(key_prefix) in state_dict_keys: + # PixArt diffusers + return None + if '{}adaln_single.emb.timestep_embedder.linear_1.bias'.format(key_prefix) in state_dict_keys: #Lightricks ltxv dit_config = {} dit_config["image_model"] = "ltxv" return dit_config + if '{}t_block.1.weight'.format(key_prefix) in state_dict_keys: # PixArt + patch_size = 2 + dit_config = {} + dit_config["num_heads"] = 16 + dit_config["patch_size"] = patch_size + dit_config["hidden_size"] = 1152 + dit_config["in_channels"] = 4 + dit_config["depth"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.') + + y_key = "{}y_embedder.y_embedding".format(key_prefix) + if y_key in state_dict_keys: + dit_config["model_max_length"] = state_dict[y_key].shape[0] + + pe_key = "{}pos_embed".format(key_prefix) + if pe_key in state_dict_keys: + 
dit_config["input_size"] = int(math.sqrt(state_dict[pe_key].shape[1])) * patch_size + dit_config["pe_interpolation"] = dit_config["input_size"] // (512//8) # guess + + ar_key = "{}ar_embedder.mlp.0.weight".format(key_prefix) + if ar_key in state_dict_keys: + dit_config["image_model"] = "pixart_alpha" + dit_config["micro_condition"] = True + else: + dit_config["image_model"] = "pixart_sigma" + dit_config["micro_condition"] = False + return dit_config + if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys: return None @@ -573,6 +604,9 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""): num_joint = count_blocks(state_dict, 'joint_transformer_blocks.{}.') num_single = count_blocks(state_dict, 'single_transformer_blocks.{}.') sd_map = comfy.utils.auraflow_to_diffusers({"n_double_layers": num_joint, "n_layers": num_joint + num_single}, output_prefix=output_prefix) + elif 'adaln_single.emb.timestep_embedder.linear_1.bias' in state_dict and 'pos_embed.proj.bias' in state_dict: # PixArt + num_blocks = count_blocks(state_dict, 'transformer_blocks.{}.') + sd_map = comfy.utils.pixart_to_diffusers({"depth": num_blocks}, output_prefix=output_prefix) elif 'x_embedder.weight' in state_dict: #Flux depth = count_blocks(state_dict, 'transformer_blocks.{}.') depth_single_blocks = count_blocks(state_dict, 'single_transformer_blocks.{}.') diff --git a/comfy/sd.py b/comfy/sd.py index dee8e9849..fef382946 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -27,6 +27,7 @@ import comfy.text_encoders.sd2_clip import comfy.text_encoders.sd3_clip import comfy.text_encoders.sa_t5 import comfy.text_encoders.aura_t5 +import comfy.text_encoders.pixart_t5 import comfy.text_encoders.hydit import comfy.text_encoders.flux import comfy.text_encoders.long_clipl @@ -604,6 +605,8 @@ class CLIPType(Enum): MOCHI = 7 LTXV = 8 HUNYUAN_VIDEO = 9 + PIXART = 10 + def load_clip(ckpt_paths, embedding_directory=None, clip_type=CLIPType.STABLE_DIFFUSION, model_options={}): clip_data = [] @@ -696,6 +699,9 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip elif clip_type == CLIPType.LTXV: clip_target.clip = comfy.text_encoders.lt.ltxv_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.lt.LTXVT5Tokenizer + elif clip_type == CLIPType.PIXART: + clip_target.clip = comfy.text_encoders.pixart_t5.pixart_te(**t5xxl_detect(clip_data)) + clip_target.tokenizer = comfy.text_encoders.pixart_t5.PixArtTokenizer else: #CLIPType.MOCHI clip_target.clip = comfy.text_encoders.genmo.mochi_te(**t5xxl_detect(clip_data)) clip_target.tokenizer = comfy.text_encoders.genmo.MochiT5Tokenizer diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index c0fe1ba52..4845406de 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -37,8 +37,12 @@ class ClipTokenWeightEncoder: sections = len(to_encode) if has_weights or sections == 0: - to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) - + if hasattr(self, "gen_empty_tokens"): + to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len)) + else: + to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len)) + print(to_encode) + o = self.encode(to_encode) out, pooled = o[:2] diff --git a/comfy/supported_models.py b/comfy/supported_models.py index 68e2b13fa..a5f38b5ed 100644 --- a/comfy/supported_models.py +++ b/comfy/supported_models.py @@ -8,6 +8,7 @@ import comfy.text_encoders.sd2_clip import comfy.text_encoders.sd3_clip import comfy.text_encoders.sa_t5 import comfy.text_encoders.aura_t5 +import 
comfy.text_encoders.pixart_t5 import comfy.text_encoders.hydit import comfy.text_encoders.flux import comfy.text_encoders.genmo @@ -592,6 +593,37 @@ class AuraFlow(supported_models_base.BASE): def clip_target(self, state_dict={}): return supported_models_base.ClipTarget(comfy.text_encoders.aura_t5.AuraT5Tokenizer, comfy.text_encoders.aura_t5.AuraT5Model) +class PixArtAlpha(supported_models_base.BASE): + unet_config = { + "image_model": "pixart_alpha", + } + + sampling_settings = { + "beta_schedule" : "sqrt_linear", + "linear_start" : 0.0001, + "linear_end" : 0.02, + "timesteps" : 1000, + } + + unet_extra_config = {} + latent_format = latent_formats.SD15 + + vae_key_prefix = ["vae."] + text_encoder_key_prefix = ["text_encoders."] + + def get_model(self, state_dict, prefix="", device=None): + out = model_base.PixArt(self, device=device) + return out.eval() + + def clip_target(self, state_dict={}): + return supported_models_base.ClipTarget(comfy.text_encoders.pixart_t5.PixArtTokenizer, comfy.text_encoders.pixart_t5.PixArtT5XXL) + +class PixArtSigma(PixArtAlpha): + unet_config = { + "image_model": "pixart_sigma", + } + latent_format = latent_formats.SDXL + class HunyuanDiT(supported_models_base.BASE): unet_config = { "image_model": "hydit", @@ -787,6 +819,6 @@ class HunyuanVideo(supported_models_base.BASE): hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}llama.transformer.".format(pref)) return supported_models_base.ClipTarget(comfy.text_encoders.hunyuan_video.HunyuanVideoTokenizer, comfy.text_encoders.hunyuan_video.hunyuan_video_clip(**hunyuan_detect)) -models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo] +models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV, HunyuanVideo] models += [SVD_img2vid] diff --git a/comfy/text_encoders/pixart_t5.py b/comfy/text_encoders/pixart_t5.py new file mode 100644 index 000000000..d56d57f1b --- /dev/null +++ b/comfy/text_encoders/pixart_t5.py @@ -0,0 +1,42 @@ +import os + +from comfy import sd1_clip +import comfy.text_encoders.t5 +import comfy.text_encoders.sd3_clip +from comfy.sd1_clip import gen_empty_tokens + +from transformers import T5TokenizerFast + +class T5XXLModel(comfy.text_encoders.sd3_clip.T5XXLModel): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def gen_empty_tokens(self, special_tokens, *args, **kwargs): + # PixArt expects the negative to be all pad tokens + special_tokens = special_tokens.copy() + special_tokens.pop("end") + return gen_empty_tokens(special_tokens, *args, **kwargs) + +class PixArtT5XXL(sd1_clip.SD1ClipModel): + def __init__(self, device="cpu", dtype=None, model_options={}): + super().__init__(device=device, dtype=dtype, name="t5xxl", clip_model=T5XXLModel, model_options=model_options) + +class T5XXLTokenizer(sd1_clip.SDTokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 
"t5_tokenizer") + super().__init__(tokenizer_path, embedding_directory=embedding_directory, pad_with_end=False, embedding_size=4096, embedding_key='t5xxl', tokenizer_class=T5TokenizerFast, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1) # no padding + +class PixArtTokenizer(sd1_clip.SD1Tokenizer): + def __init__(self, embedding_directory=None, tokenizer_data={}): + super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="t5xxl", tokenizer=T5XXLTokenizer) + +def pixart_te(dtype_t5=None, t5xxl_scaled_fp8=None): + class PixArtTEModel_(PixArtT5XXL): + def __init__(self, device="cpu", dtype=None, model_options={}): + if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options: + model_options = model_options.copy() + model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8 + if dtype is None: + dtype = dtype_t5 + super().__init__(device=device, dtype=dtype, model_options=model_options) + return PixArtTEModel_ diff --git a/comfy/utils.py b/comfy/utils.py index 3ddbfd90c..5fb5418b5 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -386,6 +386,77 @@ def mmdit_to_diffusers(mmdit_config, output_prefix=""): return key_map +PIXART_MAP_BASIC = { + ("csize_embedder.mlp.0.weight", "adaln_single.emb.resolution_embedder.linear_1.weight"), + ("csize_embedder.mlp.0.bias", "adaln_single.emb.resolution_embedder.linear_1.bias"), + ("csize_embedder.mlp.2.weight", "adaln_single.emb.resolution_embedder.linear_2.weight"), + ("csize_embedder.mlp.2.bias", "adaln_single.emb.resolution_embedder.linear_2.bias"), + ("ar_embedder.mlp.0.weight", "adaln_single.emb.aspect_ratio_embedder.linear_1.weight"), + ("ar_embedder.mlp.0.bias", "adaln_single.emb.aspect_ratio_embedder.linear_1.bias"), + ("ar_embedder.mlp.2.weight", "adaln_single.emb.aspect_ratio_embedder.linear_2.weight"), + ("ar_embedder.mlp.2.bias", "adaln_single.emb.aspect_ratio_embedder.linear_2.bias"), + ("x_embedder.proj.weight", "pos_embed.proj.weight"), + ("x_embedder.proj.bias", "pos_embed.proj.bias"), + ("y_embedder.y_embedding", "caption_projection.y_embedding"), + ("y_embedder.y_proj.fc1.weight", "caption_projection.linear_1.weight"), + ("y_embedder.y_proj.fc1.bias", "caption_projection.linear_1.bias"), + ("y_embedder.y_proj.fc2.weight", "caption_projection.linear_2.weight"), + ("y_embedder.y_proj.fc2.bias", "caption_projection.linear_2.bias"), + ("t_embedder.mlp.0.weight", "adaln_single.emb.timestep_embedder.linear_1.weight"), + ("t_embedder.mlp.0.bias", "adaln_single.emb.timestep_embedder.linear_1.bias"), + ("t_embedder.mlp.2.weight", "adaln_single.emb.timestep_embedder.linear_2.weight"), + ("t_embedder.mlp.2.bias", "adaln_single.emb.timestep_embedder.linear_2.bias"), + ("t_block.1.weight", "adaln_single.linear.weight"), + ("t_block.1.bias", "adaln_single.linear.bias"), + ("final_layer.linear.weight", "proj_out.weight"), + ("final_layer.linear.bias", "proj_out.bias"), + ("final_layer.scale_shift_table", "scale_shift_table"), +} + +PIXART_MAP_BLOCK = { + ("scale_shift_table", "scale_shift_table"), + ("attn.proj.weight", "attn1.to_out.0.weight"), + ("attn.proj.bias", "attn1.to_out.0.bias"), + ("mlp.fc1.weight", "ff.net.0.proj.weight"), + ("mlp.fc1.bias", "ff.net.0.proj.bias"), + ("mlp.fc2.weight", "ff.net.2.weight"), + ("mlp.fc2.bias", "ff.net.2.bias"), + ("cross_attn.proj.weight" ,"attn2.to_out.0.weight"), + ("cross_attn.proj.bias" ,"attn2.to_out.0.bias"), +} + +def pixart_to_diffusers(mmdit_config, output_prefix=""): + key_map = {} + + depth = 
mmdit_config.get("depth", 0) + offset = mmdit_config.get("hidden_size", 1152) + + for i in range(depth): + block_from = "transformer_blocks.{}".format(i) + block_to = "{}blocks.{}".format(output_prefix, i) + + for end in ("weight", "bias"): + s = "{}.attn1.".format(block_from) + qkv = "{}.attn.qkv.{}".format(block_to, end) + key_map["{}to_q.{}".format(s, end)] = (qkv, (0, 0, offset)) + key_map["{}to_k.{}".format(s, end)] = (qkv, (0, offset, offset)) + key_map["{}to_v.{}".format(s, end)] = (qkv, (0, offset * 2, offset)) + + s = "{}.attn2.".format(block_from) + q = "{}.cross_attn.q_linear.{}".format(block_to, end) + kv = "{}.cross_attn.kv_linear.{}".format(block_to, end) + + key_map["{}to_q.{}".format(s, end)] = q + key_map["{}to_k.{}".format(s, end)] = (kv, (0, 0, offset)) + key_map["{}to_v.{}".format(s, end)] = (kv, (0, offset, offset)) + + for k in PIXART_MAP_BLOCK: + key_map["{}.{}".format(block_from, k[1])] = "{}.{}".format(block_to, k[0]) + + for k in PIXART_MAP_BASIC: + key_map[k[1]] = "{}{}".format(output_prefix, k[0]) + + return key_map def auraflow_to_diffusers(mmdit_config, output_prefix=""): n_double_layers = mmdit_config.get("n_double_layers", 0) diff --git a/comfy_extras/nodes_pixart.py b/comfy_extras/nodes_pixart.py new file mode 100644 index 000000000..c7209c468 --- /dev/null +++ b/comfy_extras/nodes_pixart.py @@ -0,0 +1,24 @@ +from nodes import MAX_RESOLUTION + +class CLIPTextEncodePixArtAlpha: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}), + # "aspect_ratio": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 10.0, "step": 0.01}), + "text": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "encode" + CATEGORY = "advanced/conditioning" + DESCRIPTION = "Encodes text and sets the resolution conditioning for PixArt Alpha. Does not apply to PixArt Sigma." 
+ + def encode(self, clip, width, height, text): + tokens = clip.tokenize(text) + return (clip.encode_from_tokens_scheduled(tokens, add_dict={"width": width, "height": height}),) + +NODE_CLASS_MAPPINGS = { + "CLIPTextEncodePixArtAlpha": CLIPTextEncodePixArtAlpha, +} diff --git a/nodes.py b/nodes.py index 1a90073e9..bdea7564b 100644 --- a/nodes.py +++ b/nodes.py @@ -898,7 +898,7 @@ class CLIPLoader: @classmethod def INPUT_TYPES(s): return {"required": { "clip_name": (folder_paths.get_filename_list("text_encoders"), ), - "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv"], ), + "type": (["stable_diffusion", "stable_cascade", "sd3", "stable_audio", "mochi", "ltxv", "pixart"], ), }} RETURN_TYPES = ("CLIP",) FUNCTION = "load_clip" @@ -918,6 +918,8 @@ class CLIPLoader: clip_type = comfy.sd.CLIPType.MOCHI elif type == "ltxv": clip_type = comfy.sd.CLIPType.LTXV + elif type == "pixart": + clip_type = comfy.sd.CLIPType.PIXART else: clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION @@ -2164,6 +2166,7 @@ def init_builtin_extra_nodes(): "nodes_stable3d.py", "nodes_sdupscale.py", "nodes_photomaker.py", + "nodes_pixart.py", "nodes_cond.py", "nodes_morphology.py", "nodes_stable_cascade.py", From d7969cb070f9a59663d0b4ce7aabe7d49e236fc3 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Fri, 20 Dec 2024 13:24:55 -0800 Subject: [PATCH 2/2] Replace print with logging (#6138) * Replace print with logging * nit * nit * nit * nit * nit * nit --- .ci/update_windows/update.py | 14 +++++++------- app/user_manager.py | 4 ++-- comfy/cldm/cldm.py | 1 - comfy/extra_samplers/uni_pc.py | 4 ++-- comfy/hooks.py | 3 ++- comfy/ldm/aura/mmdit.py | 1 - comfy/ldm/modules/diffusionmodules/util.py | 7 ++++--- comfy/ldm/util.py | 5 +++-- comfy/model_base.py | 1 - comfy/model_management.py | 2 +- comfy/model_patcher.py | 6 +++--- comfy/sd.py | 4 ++-- comfy/sd1_clip.py | 3 +-- comfy_extras/chainner_models/model_loading.py | 3 ++- comfy_extras/nodes_hooks.py | 5 +++-- main.py | 8 ++++---- new_updater.py | 2 +- ruff.toml | 5 ++++- tests-unit/server/routes/internal_routes_test.py | 4 ++-- tests/conftest.py | 2 +- tests/inference/test_execution.py | 6 +++--- tests/inference/test_inference.py | 4 ++-- 22 files changed, 49 insertions(+), 45 deletions(-) diff --git a/.ci/update_windows/update.py b/.ci/update_windows/update.py index 59bee9804..731b6bc53 100755 --- a/.ci/update_windows/update.py +++ b/.ci/update_windows/update.py @@ -28,7 +28,7 @@ def pull(repo, remote_name='origin', branch='master'): if repo.index.conflicts is not None: for conflict in repo.index.conflicts: - print('Conflicts found in:', conflict[0].path) + print('Conflicts found in:', conflict[0].path) # noqa: T201 raise AssertionError('Conflicts, ahhhhh!!') user = repo.default_signature @@ -49,18 +49,18 @@ repo_path = str(sys.argv[1]) repo = pygit2.Repository(repo_path) ident = pygit2.Signature('comfyui', 'comfy@ui') try: - print("stashing current changes") + print("stashing current changes") # noqa: T201 repo.stash(ident) except KeyError: - print("nothing to stash") + print("nothing to stash") # noqa: T201 backup_branch_name = 'backup_branch_{}'.format(datetime.today().strftime('%Y-%m-%d_%H_%M_%S')) -print("creating backup branch: {}".format(backup_branch_name)) +print("creating backup branch: {}".format(backup_branch_name)) # noqa: T201 try: repo.branches.local.create(backup_branch_name, repo.head.peel()) except: pass -print("checking out master branch") +print("checking out master branch") # noqa: T201 branch = 
repo.lookup_branch('master') if branch is None: ref = repo.lookup_reference('refs/remotes/origin/master') @@ -72,7 +72,7 @@ else: ref = repo.lookup_reference(branch.name) repo.checkout(ref) -print("pulling latest changes") +print("pulling latest changes") # noqa: T201 pull(repo) if "--stable" in sys.argv: @@ -94,7 +94,7 @@ if "--stable" in sys.argv: if latest_tag is not None: repo.checkout(latest_tag) -print("Done!") +print("Done!") # noqa: T201 self_update = True if len(sys.argv) > 2: diff --git a/app/user_manager.py b/app/user_manager.py index e863b93dd..e7381e621 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -38,8 +38,8 @@ class UserManager(): if not os.path.exists(user_directory): os.makedirs(user_directory, exist_ok=True) if not args.multi_user: - print("****** User settings have been changed to be stored on the server instead of browser storage. ******") - print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") + logging.warning("****** User settings have been changed to be stored on the server instead of browser storage. ******") + logging.warning("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******") if args.multi_user: if os.path.isfile(self.get_users_file()): diff --git a/comfy/cldm/cldm.py b/comfy/cldm/cldm.py index f12cd6eeb..ec01665e2 100644 --- a/comfy/cldm/cldm.py +++ b/comfy/cldm/cldm.py @@ -160,7 +160,6 @@ class ControlNet(nn.Module): if isinstance(self.num_classes, int): self.label_emb = nn.Embedding(num_classes, time_embed_dim) elif self.num_classes == "continuous": - print("setting up linear c_adm embedding layer") self.label_emb = nn.Linear(1, time_embed_dim) elif self.num_classes == "sequential": assert adm_in_channels is not None diff --git a/comfy/extra_samplers/uni_pc.py b/comfy/extra_samplers/uni_pc.py index 18ff92663..b61baaa8e 100644 --- a/comfy/extra_samplers/uni_pc.py +++ b/comfy/extra_samplers/uni_pc.py @@ -2,6 +2,7 @@ import torch import math +import logging from tqdm.auto import trange @@ -474,7 +475,7 @@ class UniPC: return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs) def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True): - print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') + logging.info(f'using unified predictor-corrector with order {order} (solver type: vary coeff)') ns = self.noise_schedule assert order <= len(model_prev_list) @@ -518,7 +519,6 @@ class UniPC: A_p = C_inv_p if use_corrector: - print('using corrector') C_inv = torch.linalg.inv(C) A_c = C_inv diff --git a/comfy/hooks.py b/comfy/hooks.py index 356b7d65b..b6f0ac213 100644 --- a/comfy/hooks.py +++ b/comfy/hooks.py @@ -5,6 +5,7 @@ import math import torch import numpy as np import itertools +import logging if TYPE_CHECKING: from comfy.model_patcher import ModelPatcher, PatcherInjection @@ -575,7 +576,7 @@ def load_hook_lora_for_models(model: 'ModelPatcher', clip: 'CLIP', lora: dict[st k1 = set(k1) for x in loaded: if (x not in k) and (x not in k1): - print(f"NOT LOADED {x}") + logging.warning(f"NOT LOADED {x}") return (new_modelpatcher, new_clip, hook_group) def _combine_hooks_from_values(c_dict: dict[str, HookGroup], values: dict[str, HookGroup], cache: dict[tuple[HookGroup, HookGroup], HookGroup]): diff --git a/comfy/ldm/aura/mmdit.py b/comfy/ldm/aura/mmdit.py index 7792151aa..1258ae11f 100644 --- a/comfy/ldm/aura/mmdit.py +++ 
b/comfy/ldm/aura/mmdit.py @@ -381,7 +381,6 @@ class MMDiT(nn.Module): pe_new = pe_as_2d.squeeze(0).permute(1, 2, 0).flatten(0, 1) self.positional_encoding.data = pe_new.unsqueeze(0).contiguous() self.h_max, self.w_max = target_dim - print("PE extended to", target_dim) def pe_selection_index_based_on_dim(self, h, w): h_p, w_p = h // self.patch_size, w // self.patch_size diff --git a/comfy/ldm/modules/diffusionmodules/util.py b/comfy/ldm/modules/diffusionmodules/util.py index 9377b0737..233011dc9 100644 --- a/comfy/ldm/modules/diffusionmodules/util.py +++ b/comfy/ldm/modules/diffusionmodules/util.py @@ -9,6 +9,7 @@ import math +import logging import torch import torch.nn as nn import numpy as np @@ -130,7 +131,7 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep # add one to get the final alpha values right (the ones from first scale to data during sampling) steps_out = ddim_timesteps + 1 if verbose: - print(f'Selected timesteps for ddim sampler: {steps_out}') + logging.info(f'Selected timesteps for ddim sampler: {steps_out}') return steps_out @@ -142,8 +143,8 @@ def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): # according the the formula provided in https://arxiv.org/abs/2010.02502 sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) if verbose: - print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') - print(f'For the chosen value of eta, which is {eta}, ' + logging.info(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + logging.info(f'For the chosen value of eta, which is {eta}, ' f'this results in the following sigma_t schedule for ddim sampler {sigmas}') return sigmas, alphas, alphas_prev diff --git a/comfy/ldm/util.py b/comfy/ldm/util.py index fdd8b84a2..2ed4aa2ab 100644 --- a/comfy/ldm/util.py +++ b/comfy/ldm/util.py @@ -1,4 +1,5 @@ import importlib +import logging import torch from torch import optim @@ -23,7 +24,7 @@ def log_txt_as_img(wh, xc, size=10): try: draw.text((0, 0), lines, fill="black", font=font) except UnicodeEncodeError: - print("Cant encode string for logging. Skipping.") + logging.warning("Cant encode string for logging. 
Skipping.") txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 txts.append(txt) @@ -65,7 +66,7 @@ def mean_flat(tensor): def count_params(model, verbose=False): total_params = sum(p.numel() for p in model.parameters()) if verbose: - print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") + logging.info(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") return total_params diff --git a/comfy/model_base.py b/comfy/model_base.py index 99c53e57d..af3f0f147 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -770,7 +770,6 @@ class Flux(BaseModel): mask = torch.ones_like(noise)[:, :1] mask = torch.mean(mask, dim=1, keepdim=True) - print(mask.shape) mask = utils.common_upscale(mask.to(device), noise.shape[-1] * 8, noise.shape[-2] * 8, "bilinear", "center") mask = mask.view(mask.shape[0], mask.shape[2] // 8, 8, mask.shape[3] // 8, 8).permute(0, 2, 4, 1, 3).reshape(mask.shape[0], -1, mask.shape[2] // 8, mask.shape[3] // 8) mask = utils.resize_to_batch_size(mask, noise.shape[0]) diff --git a/comfy/model_management.py b/comfy/model_management.py index 6f667dfc5..2cbdc7392 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -1084,7 +1084,7 @@ def unload_all_models(): def resolve_lowvram_weight(weight, model, key): #TODO: remove - print("WARNING: The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.") + logging.warning("The comfy.model_management.resolve_lowvram_weight function will be removed soon, please stop using it.") return weight #TODO: might be cleaner to put this somewhere else diff --git a/comfy/model_patcher.py b/comfy/model_patcher.py index fb651242b..13684da7e 100644 --- a/comfy/model_patcher.py +++ b/comfy/model_patcher.py @@ -773,7 +773,7 @@ class ModelPatcher: return self.model.device def calculate_weight(self, patches, weight, key, intermediate_dtype=torch.float32): - print("WARNING the ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead") + logging.warning("The ModelPatcher.calculate_weight function is deprecated, please use: comfy.lora.calculate_weight instead") return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) def cleanup(self): @@ -1029,7 +1029,7 @@ class ModelPatcher: if cached_weights is not None: for key in cached_weights: if key not in model_sd_keys: - print(f"WARNING cached hook could not patch. key does not exist in model: {key}") + logging.warning(f"Cached hook could not patch. Key does not exist in model: {key}") continue self.patch_cached_hook_weights(cached_weights=cached_weights, key=key, memory_counter=memory_counter) else: @@ -1039,7 +1039,7 @@ class ModelPatcher: original_weights = self.get_key_patches() for key in relevant_patches: if key not in model_sd_keys: - print(f"WARNING cached hook would not patch. key does not exist in model: {key}") + logging.warning(f"Cached hook would not patch. 
                         continue
                     self.patch_hook_weight_to_device(hooks=hooks, combined_patches=relevant_patches, key=key, original_weights=original_weights, memory_counter=memory_counter)
diff --git a/comfy/sd.py b/comfy/sd.py
index fef382946..f79eacc24 100644
--- a/comfy/sd.py
+++ b/comfy/sd.py
@@ -940,11 +940,11 @@ def load_diffusion_model(unet_path, model_options={}):
     return model
 
 def load_unet(unet_path, dtype=None):
-    print("WARNING: the load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
+    logging.warning("The load_unet function has been deprecated and will be removed please switch to: load_diffusion_model")
    return load_diffusion_model(unet_path, model_options={"dtype": dtype})
 
 def load_unet_state_dict(sd, dtype=None):
-    print("WARNING: the load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
+    logging.warning("The load_unet_state_dict function has been deprecated and will be removed please switch to: load_diffusion_model_state_dict")
     return load_diffusion_model_state_dict(sd, model_options={"dtype": dtype})
 
 def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, metadata=None, extra_keys={}):
diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py
index 4845406de..95d41c30f 100644
--- a/comfy/sd1_clip.py
+++ b/comfy/sd1_clip.py
@@ -41,8 +41,7 @@ class ClipTokenWeightEncoder:
             to_encode.append(self.gen_empty_tokens(self.special_tokens, max_token_len))
         else:
             to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))
-        print(to_encode)
-
+
         o = self.encode(to_encode)
         out, pooled = o[:2]
diff --git a/comfy_extras/chainner_models/model_loading.py b/comfy_extras/chainner_models/model_loading.py
index d48bc238c..1bec4476f 100644
--- a/comfy_extras/chainner_models/model_loading.py
+++ b/comfy_extras/chainner_models/model_loading.py
@@ -1,5 +1,6 @@
+import logging
 from spandrel import ModelLoader
 
 def load_state_dict(state_dict):
-    print("WARNING: comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
+    logging.warning("comfy_extras.chainner_models is deprecated and has been replaced by the spandrel library.")
     return ModelLoader().load_from_state_dict(state_dict).eval()
diff --git a/comfy_extras/nodes_hooks.py b/comfy_extras/nodes_hooks.py
index d0cb69902..27fe3c423 100644
--- a/comfy_extras/nodes_hooks.py
+++ b/comfy_extras/nodes_hooks.py
@@ -1,5 +1,6 @@
 from __future__ import annotations
 from typing import TYPE_CHECKING, Union
+import logging
 import torch
 from collections.abc import Iterable
@@ -539,7 +540,7 @@ class CreateHookKeyframesInterpolated:
             is_first = False
             prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
             if print_keyframes:
-                print(f"Hook Keyframe - start_percent:{percent} = {strength}")
+                logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
         return (prev_hook_kf,)
 
 class CreateHookKeyframesFromFloats:
@@ -588,7 +589,7 @@ class CreateHookKeyframesFromFloats:
             is_first = False
             prev_hook_kf.add(comfy.hooks.HookKeyframe(strength=strength, start_percent=percent, guarantee_steps=guarantee_steps))
             if print_keyframes:
-                print(f"Hook Keyframe - start_percent:{percent} = {strength}")
+                logging.info(f"Hook Keyframe - start_percent:{percent} = {strength}")
         return (prev_hook_kf,)
 #------------------------------------------
 ###########################################
diff --git a/main.py b/main.py
index 9d1632633..b65046535 100644
--- a/main.py
+++ b/main.py
@@ -63,7 +63,7 @@ def execute_prestartup_script():
             spec.loader.exec_module(module)
             return True
         except Exception as e:
-            print(f"Failed to execute startup-script: {script_path} / {e}")
+            logging.error(f"Failed to execute startup-script: {script_path} / {e}")
         return False
 
     if args.disable_all_custom_nodes:
@@ -85,14 +85,14 @@ def execute_prestartup_script():
             success = execute_script(script_path)
             node_prestartup_times.append((time.perf_counter() - time_before, module_path, success))
     if len(node_prestartup_times) > 0:
-        print("\nPrestartup times for custom nodes:")
+        logging.info("\nPrestartup times for custom nodes:")
         for n in sorted(node_prestartup_times):
             if n[2]:
                 import_message = ""
             else:
                 import_message = " (PRESTARTUP FAILED)"
-            print("{:6.1f} seconds{}:".format(n[0], import_message), n[1])
-        print()
+            logging.info("{:6.1f} seconds{}: {}".format(n[0], import_message, n[1]))
+        logging.info("")
 
 apply_custom_paths()
 execute_prestartup_script()
diff --git a/new_updater.py b/new_updater.py
index a49e0877c..9a203acdd 100644
--- a/new_updater.py
+++ b/new_updater.py
@@ -32,4 +32,4 @@ def update_windows_updater():
         except:
             pass
     shutil.copy(bat_path, dest_bat_path)
-    print("Updated the windows standalone package updater.")
+    print("Updated the windows standalone package updater.")  # noqa: T201
diff --git a/ruff.toml b/ruff.toml
index a83d450b1..c354505f8 100644
--- a/ruff.toml
+++ b/ruff.toml
@@ -4,7 +4,10 @@ lint.ignore = ["ALL"]
 # Enable specific rules
 lint.select = [
     "S307", # suspicious-eval-usage
+    "T201", # print-usage
     # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
     # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
     "F",
-]
\ No newline at end of file
+]
+
+exclude = ["*.ipynb"]
diff --git a/tests-unit/server/routes/internal_routes_test.py b/tests-unit/server/routes/internal_routes_test.py
index 4fe544249..68c846652 100644
--- a/tests-unit/server/routes/internal_routes_test.py
+++ b/tests-unit/server/routes/internal_routes_test.py
@@ -89,9 +89,9 @@ async def test_routes_added_to_app(aiohttp_client_factory, internal_routes):
     client = await aiohttp_client_factory()
     try:
         resp = await client.get('/files')
-        print(f"Response received: status {resp.status}")
+        print(f"Response received: status {resp.status}")  # noqa: T201
     except Exception as e:
-        print(f"Exception occurred during GET request: {e}")
+        print(f"Exception occurred during GET request: {e}")  # noqa: T201
         raise
 
     assert resp.status != 404, "Route /files does not exist"
diff --git a/tests/conftest.py b/tests/conftest.py
index 1a35880af..bddfb6e15 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -28,7 +28,7 @@ def pytest_collection_modifyitems(items):
     last_items = []
     for test_name in LAST_TESTS:
         for item in items.copy():
-            print(item.module.__name__, item)
+            print(item.module.__name__, item)  # noqa: T201
             if item.module.__name__ == test_name:
                 last_items.append(item)
                 items.remove(item)
diff --git a/tests/inference/test_execution.py b/tests/inference/test_execution.py
index 3909ca68d..ca880abd2 100644
--- a/tests/inference/test_execution.py
+++ b/tests/inference/test_execution.py
@@ -134,7 +134,7 @@ class TestExecution:
         use_lru, lru_size = request.param
         if use_lru:
             pargs += ['--cache-lru', str(lru_size)]
-        print("Running server with args:", pargs)
+        print("Running server with args:", pargs)  # noqa: T201
         p = subprocess.Popen(pargs)
         yield
         p.kill()
@@ -150,8 +150,8 @@ class TestExecution:
             try:
                comfy_client.connect(listen=listen, port=port)
             except ConnectionRefusedError as e:
-                print(e)
-                print(f"({i+1}/{n_tries}) Retrying...")
+                print(e)  # noqa: T201
+                print(f"({i+1}/{n_tries}) Retrying...")  # noqa: T201
             else:
                 break
         return comfy_client
diff --git a/tests/inference/test_inference.py b/tests/inference/test_inference.py
index 1db3c06fb..d9a20c475 100644
--- a/tests/inference/test_inference.py
+++ b/tests/inference/test_inference.py
@@ -171,8 +171,8 @@ class TestInference:
             try:
                 comfy_client.connect(listen=listen, port=port)
             except ConnectionRefusedError as e:
-                print(e)
-                print(f"({i+1}/{n_tries}) Retrying...")
+                print(e)  # noqa: T201
+                print(f"({i+1}/{n_tries}) Retrying...")  # noqa: T201
             else:
                 break
         return comfy_client