Merge branch 'comfyanonymous:master' into master

patientx 2024-12-13 11:21:47 +03:00 committed by GitHub
commit 3218ed8559
40 changed files with 51 additions and 101 deletions


@@ -33,12 +33,12 @@ def pull(repo, remote_name='origin', branch='master'):
user = repo.default_signature
tree = repo.index.write_tree()
-commit = repo.create_commit('HEAD',
+repo.create_commit('HEAD',
user,
user,
'Merge!',
tree,
[repo.head.target, remote_master_id])
# We need to do this or git CLI will think we are still merging.
repo.state_cleanup()
else:


@@ -40,7 +40,7 @@ class InternalRoutes:
return web.json_response("".join([(l["t"] + " - " + l["m"]) for l in app.logger.get_logs()]))
@self.routes.get('/logs/raw')
-async def get_logs(request):
+async def get_raw_logs(request):
self.terminal_service.update_size()
return web.json_response({
"entries": list(app.logger.get_logs()),


@@ -413,7 +413,6 @@ class ControlNet(nn.Module):
out_output = []
out_middle = []
-hs = []
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)


@@ -297,7 +297,6 @@ class ControlLoraOps:
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(self, in_features: int, out_features: int, bias: bool = True,
device=None, dtype=None) -> None:
-factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.in_features = in_features
self.out_features = out_features
@@ -382,7 +381,6 @@ class ControlLora(ControlNet):
self.control_model.to(comfy.model_management.get_torch_device())
diffusion_model = model.diffusion_model
sd = diffusion_model.state_dict()
-cm = self.control_model.state_dict()
for k in sd:
weight = sd[k]
@@ -823,7 +821,7 @@ def load_t2i_adapter(t2i_data, model_options={}): #TODO: model_options
for i in range(4):
for j in range(2):
prefix_replace["adapter.body.{}.resnets.{}.".format(i, j)] = "body.{}.".format(i * 2 + j)
-prefix_replace["adapter.body.{}.".format(i, j)] = "body.{}.".format(i * 2)
+prefix_replace["adapter.body.{}.".format(i, )] = "body.{}.".format(i * 2)
prefix_replace["adapter."] = ""
t2i_data = comfy.utils.state_dict_prefix_replace(t2i_data, prefix_replace)
keys = t2i_data.keys()


@@ -703,7 +703,6 @@ class UniPC:
):
# t_0 = 1. / self.noise_schedule.total_N if t_end is None else t_end
# t_T = self.noise_schedule.T if t_start is None else t_start
-device = x.device
steps = len(timesteps) - 1
if method == 'multistep':
assert steps >= order


@@ -1,3 +1,4 @@
+import math
import torch
from torch import nn
from .ldm.modules.attention import CrossAttention


@@ -130,7 +130,7 @@ class WeightHook(Hook):
weights = self.weights
else:
weights = self.weights_clip
-k = model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
+model.add_hook_patches(hook=self, patches=weights, strength_patch=strength)
registered.append(self)
return True
# TODO: add logs about any keys that were not applied


@@ -11,7 +11,6 @@ import numpy as np
# Transfer from the input time (sigma) used in EDM to that (t) used in DEIS.
def edm2t(edm_steps, epsilon_s=1e-3, sigma_min=0.002, sigma_max=80):
-vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
vp_beta_d = 2 * (np.log(torch.tensor(sigma_min).cpu() ** 2 + 1) / epsilon_s - np.log(torch.tensor(sigma_max).cpu() ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(torch.tensor(sigma_max).cpu() ** 2 + 1) - 0.5 * vp_beta_d


@@ -97,7 +97,7 @@ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False,
raise ValueError(f"Unknown activation {activation}")
if antialias:
-act = Activation1d(act)
+act = Activation1d(act) # noqa: F821 Activation1d is not defined
return act


@@ -158,7 +158,6 @@ class RotaryEmbedding(nn.Module):
def forward(self, t):
# device = self.inv_freq.device
device = t.device
-dtype = t.dtype
# t = t.to(torch.float32)
@@ -170,7 +169,7 @@ class RotaryEmbedding(nn.Module):
if self.scale is None:
return freqs, 1.
-power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base
+power = (torch.arange(seq_len, device = device) - (seq_len // 2)) / self.scale_base # noqa: F821 seq_len is not defined
scale = comfy.ops.cast_to_input(self.scale, t) ** rearrange(power, 'n -> n 1')
scale = torch.cat((scale, scale), dim = -1)
@@ -229,9 +228,9 @@ class FeedForward(nn.Module):
linear_in = GLU(dim, inner_dim, activation, dtype=dtype, device=device, operations=operations)
else:
linear_in = nn.Sequential(
-Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
+rearrange('b n d -> b d n') if use_conv else nn.Identity(),
operations.Linear(dim, inner_dim, bias = not no_bias, dtype=dtype, device=device) if not use_conv else operations.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias, dtype=dtype, device=device),
-Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
+rearrange('b n d -> b d n') if use_conv else nn.Identity(),
activation
)
@@ -246,9 +245,9 @@ class FeedForward(nn.Module):
self.ff = nn.Sequential(
linear_in,
-Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
+rearrange('b d n -> b n d') if use_conv else nn.Identity(),
linear_out,
-Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
+rearrange('b n d -> b d n') if use_conv else nn.Identity(),
)
def forward(self, x):
@@ -346,18 +345,13 @@ class Attention(nn.Module):
# determine masking
masks = []
-final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
if input_mask is not None:
input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
masks.append(~input_mask)
# Other masks will be added here later
-if len(masks) > 0:
-final_attn_mask = ~or_reduce(masks)
-n, device = q.shape[-2], q.device
+n = q.shape[-2]
causal = self.causal if causal is None else causal


@@ -147,7 +147,6 @@ class DoubleAttention(nn.Module):
bsz, seqlen1, _ = c.shape
bsz, seqlen2, _ = x.shape
-seqlen = seqlen1 + seqlen2
cq, ck, cv = self.w1q(c), self.w1k(c), self.w1v(c)
cq = cq.view(bsz, seqlen1, self.n_heads, self.head_dim)


@@ -461,8 +461,6 @@ class AsymmDiTJoint(nn.Module):
pH, pW = H // self.patch_size, W // self.patch_size
x = self.embed_x(x) # (B, N, D), where N = T * H * W / patch_size ** 2
assert x.ndim == 3
-B = x.size(0)
pH, pW = H // self.patch_size, W // self.patch_size
N = T * pH * pW


@@ -164,9 +164,6 @@ class HunYuanControlNet(nn.Module):
),
)
-# Image embedding
-num_patches = self.x_embedder.num_patches
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList(
[


@@ -248,9 +248,6 @@ class HunYuanDiT(nn.Module):
operations.Linear(hidden_size * 4, hidden_size, bias=True, dtype=dtype, device=device),
)
-# Image embedding
-num_patches = self.x_embedder.num_patches
# HUnYuanDiT Blocks
self.blocks = nn.ModuleList([
HunYuanDiTBlock(hidden_size=hidden_size,


@@ -1,10 +1,12 @@
+import logging
+import math
import torch
from contextlib import contextmanager
from typing import Any, Dict, Tuple, Union
from comfy.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
-from comfy.ldm.util import instantiate_from_config
+from comfy.ldm.util import get_obj_from_str, instantiate_from_config
from comfy.ldm.modules.ema import LitEma
import comfy.ops
@@ -52,7+54,7 @@ class AbstractAutoencoder(torch.nn.Module):
if self.use_ema:
self.model_ema = LitEma(self, decay=ema_decay)
-logpy.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
+logging.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
def get_input(self, batch) -> Any:
raise NotImplementedError()
@@ -68,14 +70,14 @@ class AbstractAutoencoder(torch.nn.Module):
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
-logpy.info(f"{context}: Switched to EMA weights")
+logging.info(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
-logpy.info(f"{context}: Restored training weights")
+logging.info(f"{context}: Restored training weights")
def encode(self, *args, **kwargs) -> torch.Tensor:
raise NotImplementedError("encode()-method of abstract base class called")
@@ -84,7 +86,7 @@ class AbstractAutoencoder(torch.nn.Module):
raise NotImplementedError("decode()-method of abstract base class called")
def instantiate_optimizer_from_config(self, params, lr, cfg):
-logpy.info(f"loading >>> {cfg['target']} <<< optimizer from config")
+logging.info(f"loading >>> {cfg['target']} <<< optimizer from config")
return get_obj_from_str(cfg["target"])(
params, lr=lr, **cfg.get("params", dict())
)
@@ -112,7 +114,7 @@ class AutoencodingEngine(AbstractAutoencoder):
self.encoder: torch.nn.Module = instantiate_from_config(encoder_config)
self.decoder: torch.nn.Module = instantiate_from_config(decoder_config)
-self.regularization: AbstractRegularizer = instantiate_from_config(
+self.regularization = instantiate_from_config(
regularizer_config
)


@@ -157,8 +157,6 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
b, _, dim_head = query.shape
dim_head //= heads
-scale = dim_head ** -0.5
if skip_reshape:
query = query.reshape(b * heads, -1, dim_head)
value = value.reshape(b * heads, -1, dim_head)
@@ -177,9 +175,8 @@ def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None,
bytes_per_token = torch.finfo(query.dtype).bits//8
batch_x_heads, q_tokens, _ = query.shape
_, _, k_tokens = key.shape
-qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
-mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True)
+mem_free_total, _ = model_management.get_free_memory(query.device, True)
kv_chunk_size_min = None
kv_chunk_size = None
@@ -230,7 +227,6 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
scale = dim_head ** -0.5
-h = heads
if skip_reshape:
q, k, v = map(
lambda t: t.reshape(b * heads, -1, dim_head),


@@ -1,3 +1,4 @@
+from functools import partial
from typing import Dict, Optional, List
import numpy as np


@@ -162,7 +162,6 @@ def slice_attention(q, k, v):
mem_free_total = model_management.get_free_memory(q.device)
-gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
@@ -218,7 +217,7 @@ def xformers_attention(q, k, v):
try:
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)
out = out.transpose(1, 2).reshape(B, C, H, W)
-except NotImplementedError as e:
+except NotImplementedError:
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out
@@ -233,7 +232,7 @@ def pytorch_attention(q, k, v):
try:
out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(B, C, H, W)
-except model_management.OOM_EXCEPTION as e:
+except model_management.OOM_EXCEPTION:
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
out = slice_attention(q.view(B, -1, C), k.view(B, -1, C).transpose(1, 2), v.view(B, -1, C).transpose(1, 2)).reshape(B, C, H, W)
return out
@@ -546,7 +545,6 @@ class Decoder(nn.Module):
attn_op=AttnBlock,
**ignorekwargs):
super().__init__()
-if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
@@ -556,8 +554,7 @@ class Decoder(nn.Module):
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
-# compute in_ch_mult, block_in and curr_res at lowest res
-in_ch_mult = (1,)+tuple(ch_mult)
+# compute block_in and curr_res at lowest res
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)


@@ -22,7 +22,6 @@ except ImportError:
from typing import Optional, NamedTuple, List
from typing_extensions import Protocol
-from torch import Tensor
from typing import List
from comfy import model_management
@@ -172,7 +171,7 @@ def _get_attention_scores_no_kv_chunking(
del attn_scores
except model_management.OOM_EXCEPTION:
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
-attn_scores -= attn_scores.max(dim=-1, keepdim=True).values
+attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)
summed = torch.sum(attn_scores, dim=-1, keepdim=True)
attn_scores /= summed


@@ -133,7 +133,6 @@ class AdamWwithEMAandWings(optim.Optimizer):
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
-state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']


@@ -427,7 +427,6 @@ class SVD_img2vid(BaseModel):
latent_image = kwargs.get("concat_latent_image", None)
noise = kwargs.get("noise", None)
-device = kwargs["device"]
if latent_image is None:
latent_image = torch.zeros_like(noise)
@@ -711,8 +710,6 @@ class HunyuanDiT(BaseModel):
width = kwargs.get("width", 768)
height = kwargs.get("height", 768)
-crop_w = kwargs.get("crop_w", 0)
-crop_h = kwargs.get("crop_h", 0)
target_width = kwargs.get("target_width", width)
target_height = kwargs.get("target_height", height)


@@ -216,7 +216,6 @@ def detect_unet_config(state_dict, key_prefix):
num_res_blocks = []
channel_mult = []
-attention_resolutions = []
transformer_depth = []
transformer_depth_output = []
context_dim = None
@@ -388,7 +387,6 @@ def convert_config(unet_config):
t_out += [d] * (res + 1)
s *= 2
transformer_depth = t_in
-transformer_depth_output = t_out
new_config["transformer_depth"] = t_in
new_config["transformer_depth_output"] = t_out
new_config["transformer_depth_middle"] = transformer_depth_middle


@@ -525,7 +525,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
if vram_set_state == VRAMState.NO_VRAM:
lowvram_model_memory = 64 * 1024 * 1024
-cur_loaded_model = loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
+loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
current_loaded_models.insert(0, loaded_model)
return


@@ -113,7 +113,7 @@ class WrapperExecutor:
def _create_next_executor(self) -> 'WrapperExecutor':
new_idx = self.idx + 1
if new_idx > len(self.wrappers):
-raise Exception(f"Wrapper idx exceeded available wrappers; something went very wrong.")
+raise Exception("Wrapper idx exceeded available wrappers; something went very wrong.")
if self.class_obj is None:
return WrapperExecutor.new_executor(self.original, self.wrappers, new_idx)
return WrapperExecutor.new_class_executor(self.original, self.class_obj, self.wrappers, new_idx)


@@ -103,7 +103,6 @@ def cleanup_additional_models(models):
def prepare_sampling(model: 'ModelPatcher', noise_shape, conds):
-device = model.load_device
real_model: 'BaseModel' = None
models, inference_memory = get_additional_models(conds, model.model_dtype())
models += model.get_nested_additional_models() # TODO: does this require inference_memory update?


@@ -130,11 +130,6 @@ def can_concat_cond(c1, c2):
return cond_equal_size(c1.conditioning, c2.conditioning)
def cond_cat(c_list):
-c_crossattn = []
-c_concat = []
-c_adm = []
-crossattn_max_len = 0
temp = {}
for x in c_list:
for k in x:
@@ -608,8 +603,6 @@ def pre_run_control(model, conds):
for t in range(len(conds)):
x = conds[t]
-timestep_start = None
-timestep_end = None
percent_to_timestep_function = lambda a: s.percent_to_sigma(a)
if 'control' in x:
x['control'].pre_run(model, percent_to_timestep_function)


@@ -435,7 +435,7 @@ class VAE:
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
-except model_management.OOM_EXCEPTION as e:
+except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
dims = samples_in.ndim - 2
if dims == 1:
@@ -490,7 +490,7 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
-except model_management.OOM_EXCEPTION as e:
+except model_management.OOM_EXCEPTION:
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
if len(pixel_samples.shape) == 3:
samples = self.encode_tiled_1d(pixel_samples)
@@ -691,7 +691,6 @@ def load_checkpoint(config_path=None, ckpt_path=None, output_vae=True, output_cl
config = yaml.safe_load(stream)
model_config_params = config['model']['params']
clip_config = model_config_params['cond_stage_config']
-scale_factor = model_config_params['scale_factor']
if "parameterization" in model_config_params:
if model_config_params["parameterization"] == "v":


@@ -336,7 +336,6 @@ def expand_directory_list(directories):
return list(dirs)
def bundled_embed(embed, prefix, suffix): #bundled embedding in lora format
-i = 0
out_list = []
for k in embed:
if k.startswith(prefix) and k.endswith(suffix):
@@ -392,7 +391,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
embed_out = safe_load_embed_zip(embed_path)
else:
embed = torch.load(embed_path, map_location="cpu")
-except Exception as e:
+except Exception:
logging.warning("{}\n\nerror loading embedding, skipping loading: {}".format(traceback.format_exc(), embedding_name))
return None


@@ -224,7 +224,6 @@ class SDXL(supported_models_base.BASE):
def process_clip_state_dict_for_saving(self, state_dict):
replace_prefix = {}
-keys_to_replace = {}
state_dict_g = diffusers_convert.convert_text_enc_state_dict_v20(state_dict, "clip_g")
for k in state_dict:
if k.startswith("clip_l"):
@@ -527,7 +526,6 @@ class SD3(supported_models_base.BASE):
clip_l = False
clip_g = False
t5 = False
-dtype_t5 = None
pref = self.text_encoder_key_prefix[0]
if "{}clip_l.transformer.text_model.final_layer_norm.weight".format(pref) in state_dict:
clip_l = True


@@ -172,7 +172,6 @@ class T5LayerSelfAttention(torch.nn.Module):
# self.dropout = nn.Dropout(config.dropout_rate)
def forward(self, x, mask=None, past_bias=None, optimized_attention=None):
-normed_hidden_states = self.layer_norm(x)
output, past_bias = self.SelfAttention(self.layer_norm(x), mask=mask, past_bias=past_bias, optimized_attention=optimized_attention)
# x = x + self.dropout(attention_output)
x += output


@@ -22,14 +22,15 @@ class CLIPTextEncodeSDXL:
@classmethod
def INPUT_TYPES(s):
return {"required": {
+"clip": ("CLIP", ),
"width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"crop_w": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
"crop_h": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION}),
"target_width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
"target_height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
-"text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
-"text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}), "clip": ("CLIP", ),
+"text_g": ("STRING", {"multiline": True, "dynamicPrompts": True}),
+"text_l": ("STRING", {"multiline": True, "dynamicPrompts": True}),
}}
RETURN_TYPES = ("CONDITIONING",)
FUNCTION = "encode"


@@ -35,8 +35,6 @@ class HyperTile:
CATEGORY = "model_patches/unet"
def patch(self, model, tile_size, swap_size, max_depth, scale_depth):
-model_channels = model.model.model_config.unet_config["model_channels"]
latent_tile_size = max(32, tile_size) // 8
self.temp = None


@@ -240,7 +240,6 @@ class ModelSamplingContinuousV:
def patch(self, model, sampling, sigma_max, sigma_min):
m = model.clone()
-latent_format = None
sigma_data = 1.0
if sampling == "v_prediction":
sampling_type = comfy.model_sampling.V_PREDICTION


@@ -760,7 +760,7 @@ def validate_prompt(prompt):
if 'class_type' not in prompt[x]:
error = {
"type": "invalid_prompt",
-"message": f"Cannot execute because a node is missing the class_type property.",
+"message": "Cannot execute because a node is missing the class_type property.",
"details": f"Node ID '#{x}'",
"extra_info": {}
}


@@ -22,7 +22,7 @@ def fix_pytorch_libomp():
if b"libomp140.x86_64.dll" not in contents:
break
try:
-mydll = ctypes.cdll.LoadLibrary(test_file)
-except FileNotFoundError as e:
+ctypes.cdll.LoadLibrary(test_file)
+except FileNotFoundError:
logging.warning("Detected pytorch version with libomp issue, patching.")
shutil.copyfile(os.path.join(lib_folder, "libiomp5md.dll"), dest)


@@ -112,6 +112,7 @@ def cuda_malloc_warning():
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
def prompt_worker(q, server):
+current_time: float = 0.0
e = execution.PromptExecutor(server, lru_size=args.cache_lru)
last_gc_collect = 0
need_gc = False


@@ -237,11 +237,7 @@
"source": [
"!npm install -g localtunnel\n",
"\n",
-"import subprocess\n",
"import threading\n",
-"import time\n",
-"import socket\n",
-"import urllib.request\n",
"\n",
"def iframe_thread(port):\n",
" while True:\n",
@@ -288,8 +284,6 @@
"outputs": [],
"source": [
"import threading\n",
-"import time\n",
-"import socket\n",
"def iframe_thread(port):\n",
" while True:\n",
" time.sleep(0.5)\n",


@@ -4,5 +4,7 @@ lint.ignore = ["ALL"]
# Enable specific rules
lint.select = [
"S307", # suspicious-eval-usage
-"F401", # unused-import
+# The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
+# See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
+"F",
]
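As background for the ruff.toml change above: enabling the whole Pyflakes "F" family (rather than only F401) is what drives the cleanups seen throughout this merge, namely unused imports and locals removed, undefined names marked with noqa, and pointless f-string prefixes dropped. The snippet below is a hypothetical illustration, not part of the commit, with invented names, showing the kind of code each of those rules flags:

import os                              # F401: 'os' imported but unused

def scale_latent(latent, factor):
    original_shape = latent.shape      # F841: local variable assigned but never used
    label = f"scaled"                  # F541: f-string without any placeholders
    return latnt * factor              # F821: undefined name 'latnt' (deliberate typo)

The "# noqa: F821" comments added elsewhere in this diff suppress that last rule on specific lines rather than changing the code.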


@@ -563,7 +563,7 @@ class PromptServer():
for x in nodes.NODE_CLASS_MAPPINGS:
try:
out[x] = node_info(x)
-except Exception as e:
+except Exception:
logging.error(f"[ERROR] An error occurred while retrieving information for the '{x}' node.")
logging.error(traceback.format_exc())
return web.json_response(out)
@@ -584,7 +584,7 @@ class PromptServer():
return web.json_response(self.prompt_queue.get_history(max_items=max_items))
@routes.get("/history/{prompt_id}")
-async def get_history(request):
+async def get_history_prompt_id(request):
prompt_id = request.match_info.get("prompt_id", None)
return web.json_response(self.prompt_queue.get_history(prompt_id=prompt_id))
@@ -599,8 +599,6 @@ class PromptServer():
@routes.post("/prompt")
async def post_prompt(request):
logging.info("got prompt")
-resp_code = 200
-out_string = ""
json_data = await request.json()
json_data = self.trigger_on_prompt(json_data)
@@ -832,8 +830,8 @@ class PromptServer():
for handler in self.on_prompt_handlers:
try:
json_data = handler(json_data)
-except Exception as e:
-logging.warning(f"[ERROR] An error occurred during the on_prompt_handler processing")
+except Exception:
+logging.warning("[ERROR] An error occurred during the on_prompt_handler processing")
logging.warning(traceback.format_exc())
return json_data


@@ -259,7 +259,7 @@ class TestForLoopOpen:
graph = GraphBuilder()
if "initial_value0" in kwargs:
remaining = kwargs["initial_value0"]
-while_open = graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)})
+graph.node("TestWhileLoopOpen", condition=remaining, initial_value0=remaining, **{(f"initial_value{i}"): kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)})
outputs = [kwargs.get(f"initial_value{i}", None) for i in range(1, NUM_FLOW_SOCKETS)]
return {
"result": tuple(["stub", remaining] + outputs),