Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-24 21:30:15 +08:00)

Compare commits: 02529c6d57 ... 72ca18acc2 (15 commits)
Commits in this comparison (SHA1):

- 72ca18acc2
- f588e6c821
- 0da072e098
- 31d358c78c
- 4dd42ef1b7
- acbf08cd60
- 53e762a3af
- 9a552df898
- f2fda021ab
- 303b1735f8
- 9e5f677746
- 65cfcf5b1b
- 1bdc9a947f
- d622a61874
- 236b9e211d
@@ -44,7 +44,7 @@ class ModelFileManager:
         @routes.get("/experiment/models/{folder}")
         async def get_all_models(request):
             folder = request.match_info.get("folder", None)
-            if not folder in folder_paths.folder_names_and_paths:
+            if folder not in folder_paths.folder_names_and_paths:
                 return web.Response(status=404)
             files = self.get_model_file_list(folder)
             return web.json_response(files)

@@ -55,7 +55,7 @@ class ModelFileManager:
             path_index = int(request.match_info.get("path_index", None))
             filename = request.match_info.get("filename", None)

-            if not folder_name in folder_paths.folder_names_and_paths:
+            if folder_name not in folder_paths.folder_names_and_paths:
                 return web.Response(status=404)

             folders = folder_paths.folder_names_and_paths[folder_name]
@@ -2,6 +2,25 @@ import torch
 from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.ops

+def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
+    image = image[:, :, :, :3] if image.shape[3] > 3 else image
+    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
+    std = torch.tensor(std, device=image.device, dtype=image.dtype)
+    image = image.movedim(-1, 1)
+    if not (image.shape[2] == size and image.shape[3] == size):
+        if crop:
+            scale = (size / min(image.shape[2], image.shape[3]))
+            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
+        else:
+            scale_size = (size, size)
+
+        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
+        h = (image.shape[2] - size)//2
+        w = (image.shape[3] - size)//2
+        image = image[:,:,h:h+size,w:w+size]
+    image = torch.clip((255. * image), 0, 255).round() / 255.0
+    return (image - mean.view([3,1,1])) / std.view([3,1,1])
+
 class CLIPAttention(torch.nn.Module):
     def __init__(self, embed_dim, heads, dtype, device, operations):
         super().__init__()
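Note: for orientation, a minimal sketch of how the relocated helper can be exercised on its own. The input layout ([batch, height, width, channels] floats in 0..1) is inferred from the movedim and normalization above; the call below is an illustration, not part of the diff.

```python
# Hypothetical usage sketch for the relocated helper (not part of the diff).
import torch
import comfy.clip_model

image = torch.rand(1, 512, 768, 3)  # assumed [B, H, W, C] image in 0..1
pixel_values = comfy.clip_model.clip_preprocess(image, size=224, crop=True)
print(pixel_values.shape)  # expected torch.Size([1, 3, 224, 224]), CLIP-normalized
```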
@@ -1,6 +1,5 @@
 from .utils import load_torch_file, transformers_convert, state_dict_prefix_replace
 import os
 import torch
 import json
 import logging

@@ -17,24 +16,7 @@ class Output:
     def __setitem__(self, key, item):
         setattr(self, key, item)

-def clip_preprocess(image, size=224, mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711], crop=True):
-    image = image[:, :, :, :3] if image.shape[3] > 3 else image
-    mean = torch.tensor(mean, device=image.device, dtype=image.dtype)
-    std = torch.tensor(std, device=image.device, dtype=image.dtype)
-    image = image.movedim(-1, 1)
-    if not (image.shape[2] == size and image.shape[3] == size):
-        if crop:
-            scale = (size / min(image.shape[2], image.shape[3]))
-            scale_size = (round(scale * image.shape[2]), round(scale * image.shape[3]))
-        else:
-            scale_size = (size, size)
-
-        image = torch.nn.functional.interpolate(image, size=scale_size, mode="bicubic", antialias=True)
-        h = (image.shape[2] - size)//2
-        w = (image.shape[3] - size)//2
-        image = image[:,:,h:h+size,w:w+size]
-    image = torch.clip((255. * image), 0, 255).round() / 255.0
-    return (image - mean.view([3,1,1])) / std.view([3,1,1])
+clip_preprocess = comfy.clip_model.clip_preprocess # Prevent some stuff from breaking, TODO: remove eventually

 IMAGE_ENCODERS = {
     "clip_vision_model": comfy.clip_model.CLIPVisionModelProjection,
@@ -73,7 +55,7 @@ class ClipVisionModel():

     def encode_image(self, image, crop=True):
         comfy.model_management.load_model_gpu(self.patcher)
-        pixel_values = clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
+        pixel_values = comfy.clip_model.clip_preprocess(image.to(self.load_device), size=self.image_size, mean=self.image_mean, std=self.image_std, crop=crop).float()
         out = self.model(pixel_values=pixel_values, intermediate_output='all' if self.return_all_hidden_states else -2)

         outputs = Output()
@@ -527,7 +527,8 @@ class HookKeyframeGroup:
                     if self._current_keyframe.get_effective_guarantee_steps(max_sigma) > 0:
                         break
                 # if eval_c is outside the percent range, stop looking further
-                else: break
+                else:
+                    break
         # update steps current context is used
         self._current_used_steps += 1
         # update current timestep this was performed on
@@ -270,7 +270,7 @@ class ChromaRadiance(Chroma):
        bad_keys = tuple(
            k
            for k, v in overrides.items()
-            if type(v) != type(getattr(params, k)) and (v is not None or k not in nullable_keys)
+            if not isinstance(v, type(getattr(params, k))) and (v is not None or k not in nullable_keys)
        )
        if bad_keys:
            e = f"Invalid value(s) in transformer_options chroma_radiance_options: {', '.join(bad_keys)}"
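Note: the switch to isinstance loosens the check slightly, since isinstance also accepts subclasses of the parameter's type. A standalone illustration (not repository code):

```python
# Standalone illustration of the behavioral difference (not repository code).
class MyFloat(float):
    pass

v, default = MyFloat(1.5), 2.0
print(type(v) != type(default))          # True  -> the old check would flag this override
print(not isinstance(v, type(default)))  # False -> the new check accepts the subclass
```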
@@ -3,7 +3,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
 from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
-import model_management, model_patcher
+import model_management
+import model_patcher

 class SRResidualCausalBlock3D(nn.Module):
     def __init__(self, channels: int):
@@ -98,7 +98,7 @@ def var_attn_arg(kwargs):
     cu_seqlens_k = kwargs.get("cu_seqlens_k", cu_seqlens_q)
     max_seqlen_q = kwargs.get("max_seqlen_q", None)
     max_seqlen_k = kwargs.get("max_seqlen_k", max_seqlen_q)
-    assert cu_seqlens_q != None, "cu_seqlens_q shouldn't be None when var_length is True"
+    assert cu_seqlens_q is not None, "cu_seqlens_q shouldn't be None when var_length is True"
     return cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k

 # feedforward
 class GEGLU(nn.Module):
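Note: a hedged sketch of the packed variable-length convention these attention hunks assume: sequences of different lengths share one token dimension and are described by cumulative offsets (cu_seqlens). The lengths below are made up for illustration.

```python
# Illustrative values only; the lengths and shapes are assumptions, not from the diff.
import torch

lengths = [3, 5, 2]                                            # three packed sequences
cu_seqlens_q = torch.tensor([0, 3, 8, 10], dtype=torch.int32)  # cumulative offsets
kwargs = {"cu_seqlens_q": cu_seqlens_q, "max_seqlen_q": max(lengths)}
# var_attn_arg(kwargs) would return (cu_seqlens_q, cu_seqlens_q, 5, 5),
# since the key-side values default to the query-side ones.
```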
@@ -423,7 +423,8 @@ except:
     pass

 @wrap_attn
-def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     b = q.shape[0]
     dim_head = q.shape[-1]
     # check to make sure xformers isn't broken

@@ -448,9 +449,12 @@ def attention_xformers(q, k, v, heads, mask=None, attn_precision=None, skip_resh
         k = k.view(1, total_tokens, heads, dim_head)
         v = v.view(1, total_tokens, heads, dim_head)
     else:
-        if q.ndim == 3: q = q.unsqueeze(0)
-        if k.ndim == 3: k = k.unsqueeze(0)
-        if v.ndim == 3: v = v.unsqueeze(0)
+        if q.ndim == 3:
+            q = q.unsqueeze(0)
+        if k.ndim == 3:
+            k = k.unsqueeze(0)
+        if v.ndim == 3:
+            v = v.unsqueeze(0)
         dim_head = q.shape[-1]

     target_output_shape = (q.shape[1], -1)
@@ -513,7 +517,8 @@ else:
     SDP_BATCH_LIMIT = 2**31

 @wrap_attn
-def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     if var_length:
         cu_seqlens_q, cu_seqlens_k, _, _ = var_attn_arg(kwargs)
         if not skip_reshape:

@@ -524,7 +529,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
             k = k.view(k.shape[0], heads, head_dim)
             v = v.view(v.shape[0], heads, head_dim)

-        b = q.size(0); dim_head = q.shape[-1]
+        b = q.size(0)
+        dim_head = q.shape[-1]
         q = torch.nested.nested_tensor_from_jagged(q, offsets=cu_seqlens_q.long())
         k = torch.nested.nested_tensor_from_jagged(k, offsets=cu_seqlens_k.long())
         v = torch.nested.nested_tensor_from_jagged(v, offsets=cu_seqlens_k.long())
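Note: background for the nested-tensor path above, as a small self-contained sketch. torch.nested.nested_tensor_from_jagged turns a packed [total_tokens, heads, dim_head] tensor plus offsets into a jagged batch that scaled_dot_product_attention can consume; the sizes are made up, and a recent PyTorch with the jagged layout is assumed.

```python
# Sizes are made up; a recent PyTorch with the jagged nested-tensor layout is assumed.
import torch

heads, dim_head = 4, 8
cu_seqlens = torch.tensor([0, 3, 8, 10])      # three sequences of lengths 3, 5, 2
packed = torch.randn(10, heads, dim_head)     # [total_tokens, heads, dim_head]

nt = torch.nested.nested_tensor_from_jagged(packed, offsets=cu_seqlens.long())
print(nt.is_nested, nt.size(0))               # True 3 (one entry per sequence)
```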
@@ -577,7 +583,8 @@ def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None, skip_resha
     return out

 @wrap_attn
-def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length = False, **kwargs):
+def attention_sage(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     exception_fallback = False
     if var_length:
         if not skip_reshape:

@@ -750,7 +757,8 @@ except AttributeError as error:
     assert False, f"Could not define flash_attn_wrapper: {FLASH_ATTN_ERROR}"

 @wrap_attn
-def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, var_length=False, **kwargs):
+def attention_flash(q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False, skip_output_reshape=False, **kwargs):
+    var_length = kwargs.get("var_length", False)
     if var_length:
         cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k = var_attn_arg(kwargs)
         return flash_attn_varlen_func(
@@ -397,7 +397,8 @@ class Model(nn.Module):
                  attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
                  resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
         super().__init__()
-        if use_linear_attn: attn_type = "linear"
+        if use_linear_attn:
+            attn_type = "linear"
         self.ch = ch
         self.temb_ch = self.ch*4
         self.num_resolutions = len(ch_mult)

@@ -551,7 +552,8 @@ class Encoder(nn.Module):
                  conv3d=False, time_compress=None,
                  **ignore_kwargs):
         super().__init__()
-        if use_linear_attn: attn_type = "linear"
+        if use_linear_attn:
+            attn_type = "linear"
         self.ch = ch
         self.temb_ch = 0
         self.num_resolutions = len(ch_mult)
@@ -45,7 +45,7 @@ class LitEma(nn.Module):
                     shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
                     shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
                 else:
-                    assert not key in self.m_name2s_name
+                    assert key not in self.m_name2s_name

     def copy_to(self, model):
         m_param = dict(model.named_parameters())

@@ -54,7 +54,7 @@ class LitEma(nn.Module):
             if m_param[key].requires_grad:
                 m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
             else:
-                assert not key in self.m_name2s_name
+                assert key not in self.m_name2s_name

     def store(self, parameters):
         """
@@ -13,6 +13,7 @@ from comfy.rmsnorm import RMSNorm
 from torch.nn.modules.utils import _triple
 from torch import nn
 import math
+from comfy.ldm.flux.math import apply_rope1

 class Cache:
     def __init__(self, disable=False, prefix="", cache=None):

@@ -443,7 +444,6 @@ def apply_rotary_emb(
     freqs_seq_dim = None
 ):
     dtype = t.dtype

     if not exists(freqs_seq_dim):
         if freqs.ndim == 2 or t.ndim == 3:
             freqs_seq_dim = 0
@@ -452,20 +452,23 @@ def apply_rotary_emb(
         seq_len = t.shape[seq_dim]
         freqs = slice_at_dim(freqs, slice(-seq_len, None), dim = freqs_seq_dim)

-    rot_dim = freqs.shape[-1]
-    end_index = start_index + rot_dim
-
-    assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
+    rot_feats = freqs.shape[-1]
+    end_index = start_index + rot_feats

     t_left = t[..., :start_index]
     t_middle = t[..., start_index:end_index]
     t_right = t[..., end_index:]

-    freqs = freqs.to(t_middle.device)
-    t_transformed = (t_middle * freqs.cos() * scale) + (rotate_half(t_middle) * freqs.sin() * scale)
+    angles = freqs.to(t_middle.device)[..., ::2]
+    cos = torch.cos(angles) * scale
+    sin = torch.sin(angles) * scale

-    out = torch.cat((t_left, t_transformed, t_right), dim=-1)
+    col0 = torch.stack([cos, sin], dim=-1)
+    col1 = torch.stack([-sin, cos], dim=-1)
+    freqs_mat = torch.stack([col0, col1], dim=-1)
+
+    t_middle_out = apply_rope1(t_middle, freqs_mat)
+    out = torch.cat((t_left, t_middle_out, t_right), dim=-1)
     return out.type(dtype)

 class NaMMRotaryEmbedding3d(MMRotaryEmbeddingBase):
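Note: the refactor above swaps the t*cos + rotate_half(t)*sin formulation for an explicit per-pair 2x2 rotation matrix (columns [cos, sin] and [-sin, cos]) that apply_rope1 consumes. A self-contained sketch of why the two are equivalent; the rotate_half and rope helpers below are local stand-ins written for illustration, not the repository implementations.

```python
# Self-contained equivalence sketch; the helpers are local stand-ins, not repository code.
import torch

def rotate_half(x):
    # rotate each (even, odd) feature pair by 90 degrees: (x1, x2) -> (-x2, x1)
    x = x.reshape(*x.shape[:-1], -1, 2)
    x1, x2 = x.unbind(dim=-1)
    return torch.stack((-x2, x1), dim=-1).reshape(*x.shape[:-2], -1)

def apply_rope_pairs(x, rot_mat):
    # rot_mat: [..., pairs, 2, 2]; treat x as [..., pairs] column vectors of size 2
    x_ = x.reshape(*x.shape[:-1], -1, 1, 2)
    out = rot_mat[..., 0] * x_[..., 0] + rot_mat[..., 1] * x_[..., 1]
    return out.reshape(*x.shape)

t = torch.randn(2, 6, 8)                     # [batch, seq, features]
angles = torch.randn(6, 4)                   # one angle per feature pair
cos, sin = angles.cos(), angles.sin()

# matrix form, as built in the new code path
rot_mat = torch.stack([torch.stack([cos, sin], dim=-1),
                       torch.stack([-sin, cos], dim=-1)], dim=-1)
matrix_out = apply_rope_pairs(t, rot_mat)

# trig form, as in the old code path (each angle repeated over its pair)
cos_full, sin_full = cos.repeat_interleave(2, -1), sin.repeat_interleave(2, -1)
trig_out = t * cos_full + rotate_half(t) * sin_full

print(torch.allclose(matrix_out, trig_out, atol=1e-6))  # True
```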
@@ -16,6 +16,7 @@ import math
 from enum import Enum
 from comfy.ops import NVIDIA_MEMORY_CONV_BUG_WORKAROUND

+import logging
 import comfy.ops
 ops = comfy.ops.disable_weight_init

@@ -446,6 +447,7 @@ class InflatedCausalConv3d(ops.Conv3d):
         self.memory_device = memory_device
         self.padding = (0, *self.padding[1:])
         self.memory_limit = float("inf")
+        self.logged_once = False

     def set_memory_limit(self, value: float):
         self.memory_limit = value
@@ -469,8 +471,16 @@ class InflatedCausalConv3d(ops.Conv3d):
                 return out
             except RuntimeError:
                 pass
-
-        return super()._conv_forward(input, weight, bias, *args, **kwargs)
+            except NotImplementedError:
+                pass
+        try:
+            return super()._conv_forward(input, weight, bias, *args, **kwargs)
+        except NotImplementedError:
+            # for: Could not run 'aten::cudnn_convolution' with arguments from the 'CPU' backend
+            if not self.logged_once:
+                logging.warning("VAE is on CPU for decoding. This is most likely due to not enough memory")
+                self.logged_once = True
+            return F.conv3d(input, weight, bias, *args, **kwargs)

     def memory_limit_conv(
         self,
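Note: a simplified standalone sketch of the fallback chain this hunk builds: attempt the normal convolution and drop to a plain F.conv3d only when the backend refuses, logging the slow path once. The class below is a made-up stand-in, not the repository's InflatedCausalConv3d.

```python
# Made-up stand-in illustrating the fallback chain, not the repository class.
import logging
import torch
import torch.nn.functional as F

class FallbackConv3d(torch.nn.Conv3d):
    logged_once = False

    def _conv_forward(self, input, weight, bias):
        try:
            return super()._conv_forward(input, weight, bias)
        except NotImplementedError:
            # e.g. aten::cudnn_convolution not available for the 'CPU' backend
            if not self.logged_once:
                logging.warning("convolution fell back to F.conv3d")
                self.logged_once = True
            return F.conv3d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
```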
@@ -71,7 +71,7 @@ def count_params(model, verbose=False):


 def instantiate_from_config(config):
-    if not "target" in config:
+    if "target" not in config:
         if config == '__is_first_stage__':
             return None
         elif config == "__is_unconditional__":
@@ -1542,6 +1542,10 @@ def soft_empty_cache(force=False):
 def unload_all_models():
     free_memory(1e30, get_torch_device())

+def debug_memory_summary():
+    if is_amd() or is_nvidia():
+        return torch.cuda.memory.memory_summary()
+    return ""

 #TODO: might be cleaner to put this somewhere else
 import threading
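Note: the new helper is consumed later in this comparison by the OOM handler in the executor hunk. A hedged sketch of that call pattern; the surrounding try/except scaffolding here is illustrative, not from the diff.

```python
# Illustrative call pattern; the surrounding scaffolding is not from the diff.
import logging
import comfy.model_management

try:
    pass  # model execution would run here
except comfy.model_management.OOM_EXCEPTION:
    logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
    comfy.model_management.unload_all_models()
    raise
```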
@@ -382,8 +382,8 @@ class VAE:
             self.latent_channels = 16
         elif "decoder.up_blocks.2.upsamplers.0.upscale_conv.weight" in sd: # seedvr2
             self.first_stage_model = comfy.ldm.seedvr.vae.VideoAutoencoderKLWrapper()
-            self.memory_used_decode = lambda shape, dtype: (10 * shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
-            self.memory_used_encode = lambda shape, dtype: (10 * max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
+            self.memory_used_decode = lambda shape, dtype: (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * model_management.dtype_size(dtype)
+            self.memory_used_encode = lambda shape, dtype: (max(shape[1], 5) * shape[2] * shape[3]) * model_management.dtype_size(dtype)
             self.working_dtypes = [torch.bfloat16, torch.float32]
             self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
             self.downscale_index_formula = (4, 8, 8)
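Note: to get a feel for the magnitude of the revised decode estimate, a short worked calculation that just plugs assumed placeholder values into the lambda; what each shape index corresponds to follows the VAE's latent layout rather than anything stated in this hunk.

```python
# Assumed placeholder values plugged into the revised decode estimate.
shape = (1, 16, 32, 32)   # placeholder shape[0..3] values
dtype_size = 2            # bytes per element for torch.bfloat16
estimate = (shape[1] * shape[2] * shape[3] * (4 * 8 * 8)) * dtype_size
print(estimate)           # 16 * 32 * 32 * 256 * 2 = 8,388,608 bytes (~8 MiB)
```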
@@ -154,7 +154,8 @@ class TAEHV(nn.Module):
         self._show_progress_bar = value

     def encode(self, x, **kwargs):
-        if self.patch_size > 1: x = F.pixel_unshuffle(x, self.patch_size)
+        if self.patch_size > 1:
+            x = F.pixel_unshuffle(x, self.patch_size)
         x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
         if x.shape[1] % 4 != 0:
             # pad at end to multiple of 4

@@ -167,5 +168,6 @@ class TAEHV(nn.Module):
     def decode(self, x, **kwargs):
         x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
         x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
-        if self.patch_size > 1: x = F.pixel_shuffle(x, self.patch_size)
+        if self.patch_size > 1:
+            x = F.pixel_shuffle(x, self.patch_size)
         return x[:, self.frames_to_trim:].movedim(2, 1)
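Note: pixel_unshuffle folds spatial detail into channels before the encoder and pixel_shuffle unfolds it after decoding, which is the patching trick these hunks reformat. A tiny standalone round-trip with made-up sizes:

```python
# Standalone round-trip with made-up sizes (not repository code).
import torch
import torch.nn.functional as F

patch_size = 2
x = torch.randn(1, 3, 8, 8)                    # [B, C, H, W]
packed = F.pixel_unshuffle(x, patch_size)      # -> [1, 12, 4, 4]
restored = F.pixel_shuffle(packed, patch_size) # -> [1, 3, 8, 8]
print(packed.shape, torch.equal(x, restored))  # torch.Size([1, 12, 4, 4]) True
```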
@@ -8,7 +8,6 @@ from comfy.ldm.modules.attention import optimized_attention_for_device
 import comfy.model_management
 import comfy.ldm.common_dit

-import comfy.model_management
 from . import qwen_vl

 @dataclass
@@ -1230,6 +1230,8 @@ def convert_old_quants(state_dict, model_prefix="", metadata={}):
     out_sd = {}
     layers = {}
     for k in list(state_dict.keys()):
+        if k == scaled_fp8_key:
+            continue
         if not k.startswith(model_prefix):
            out_sd[k] = state_dict[k]
            continue
@@ -807,6 +807,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
                 ),
                 IO.Combo.Input("aspect_ratio", options=["16:9", "9:16", "1:1"]),
                 IO.Combo.Input("duration", options=[5, 10]),
+                IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
             ],
             outputs=[
                 IO.Video.Output(),

@@ -826,6 +827,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
         prompt: str,
         aspect_ratio: str,
         duration: int,
+        resolution: str = "1080p",
     ) -> IO.NodeOutput:
         validate_string(prompt, min_length=1, max_length=2500)
         response = await sync_op(

@@ -837,6 +839,7 @@ class OmniProTextToVideoNode(IO.ComfyNode):
                 prompt=prompt,
                 aspect_ratio=aspect_ratio,
                 duration=str(duration),
+                mode="pro" if resolution == "1080p" else "std",
             ),
         )
         return await finish_omni_video_task(cls, response)
@@ -872,6 +875,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
                     optional=True,
                     tooltip="Up to 6 additional reference images.",
                 ),
+                IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
             ],
             outputs=[
                 IO.Video.Output(),

@@ -893,6 +897,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
         first_frame: Input.Image,
         end_frame: Input.Image | None = None,
         reference_images: Input.Image | None = None,
+        resolution: str = "1080p",
     ) -> IO.NodeOutput:
         prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)

@@ -936,6 +941,7 @@ class OmniProFirstLastFrameNode(IO.ComfyNode):
                 prompt=prompt,
                 duration=str(duration),
                 image_list=image_list,
+                mode="pro" if resolution == "1080p" else "std",
             ),
         )
         return await finish_omni_video_task(cls, response)
@@ -964,6 +970,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
                     "reference_images",
                     tooltip="Up to 7 reference images.",
                 ),
+                IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
             ],
             outputs=[
                 IO.Video.Output(),

@@ -984,6 +991,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
         aspect_ratio: str,
         duration: int,
         reference_images: Input.Image,
+        resolution: str = "1080p",
     ) -> IO.NodeOutput:
         prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)

@@ -1005,6 +1013,7 @@ class OmniProImageToVideoNode(IO.ComfyNode):
                 aspect_ratio=aspect_ratio,
                 duration=str(duration),
                 image_list=image_list,
+                mode="pro" if resolution == "1080p" else "std",
             ),
         )
         return await finish_omni_video_task(cls, response)
@@ -1036,6 +1045,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
                     tooltip="Up to 4 additional reference images.",
                     optional=True,
                 ),
+                IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
             ],
             outputs=[
                 IO.Video.Output(),

@@ -1058,6 +1068,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
         reference_video: Input.Video,
         keep_original_sound: bool,
         reference_images: Input.Image | None = None,
+        resolution: str = "1080p",
     ) -> IO.NodeOutput:
         prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)

@@ -1090,6 +1101,7 @@ class OmniProVideoToVideoNode(IO.ComfyNode):
                 duration=str(duration),
                 image_list=image_list if image_list else None,
                 video_list=video_list,
+                mode="pro" if resolution == "1080p" else "std",
             ),
         )
         return await finish_omni_video_task(cls, response)
@@ -1119,6 +1131,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
                     tooltip="Up to 4 additional reference images.",
                     optional=True,
                 ),
+                IO.Combo.Input("resolution", options=["1080p", "720p"], optional=True),
             ],
             outputs=[
                 IO.Video.Output(),

@@ -1139,6 +1152,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
         video: Input.Video,
         keep_original_sound: bool,
         reference_images: Input.Image | None = None,
+        resolution: str = "1080p",
     ) -> IO.NodeOutput:
         prompt = normalize_omni_prompt_references(prompt)
         validate_string(prompt, min_length=1, max_length=2500)

@@ -1171,6 +1185,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
                 duration=None,
                 image_list=image_list if image_list else None,
                 video_list=video_list,
+                mode="pro" if resolution == "1080p" else "std",
             ),
         )
         return await finish_omni_video_task(cls, response)
@@ -155,7 +155,7 @@ class TripoTextToModelNode(IO.ComfyNode):
                 model_seed=model_seed,
                 texture_seed=texture_seed,
                 texture_quality=texture_quality,
-                face_limit=face_limit,
+                face_limit=face_limit if face_limit != -1 else None,
                 geometry_quality=geometry_quality,
                 auto_size=True,
                 quad=quad,

@@ -255,7 +255,7 @@ class TripoImageToModelNode(IO.ComfyNode):
                 texture_alignment=texture_alignment,
                 texture_seed=texture_seed,
                 texture_quality=texture_quality,
-                face_limit=face_limit,
+                face_limit=face_limit if face_limit != -1 else None,
                 auto_size=True,
                 quad=quad,
             ),

@@ -369,7 +369,7 @@ class TripoMultiviewToModelNode(IO.ComfyNode):
                 texture_quality=texture_quality,
                 geometry_quality=geometry_quality,
                 texture_alignment=texture_alignment,
-                face_limit=face_limit,
+                face_limit=face_limit if face_limit != -1 else None,
                 quad=quad,
             ),
         )
@@ -207,15 +207,15 @@ class ExecutionList(TopologicalSort):
         return self.output_cache.get(node_id) is not None

     def cache_link(self, from_node_id, to_node_id):
-        if not to_node_id in self.execution_cache:
+        if to_node_id not in self.execution_cache:
             self.execution_cache[to_node_id] = {}
         self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
-        if not from_node_id in self.execution_cache_listeners:
+        if from_node_id not in self.execution_cache_listeners:
             self.execution_cache_listeners[from_node_id] = set()
         self.execution_cache_listeners[from_node_id].add(to_node_id)

     def get_cache(self, from_node_id, to_node_id):
-        if not to_node_id in self.execution_cache:
+        if to_node_id not in self.execution_cache:
             return None
         value = self.execution_cache[to_node_id].get(from_node_id)
         if value is None:
@@ -55,7 +55,8 @@ class APG(io.ComfyNode):
         def pre_cfg_function(args):
             nonlocal running_avg, prev_sigma

-            if len(args["conds_out"]) == 1: return args["conds_out"]
+            if len(args["conds_out"]) == 1:
+                return args["conds_out"]

             cond = args["conds_out"][0]
             uncond = args["conds_out"][1]

@@ -10,7 +10,7 @@ class Mahiro(io.ComfyNode):
     def define_schema(cls):
         return io.Schema(
             node_id="Mahiro",
-            display_name="Mahiro is so cute that she deserves a better guidance function!! (。・ω・。)",
+            display_name="Mahiro CFG",
             category="_for_testing",
             description="Modify the guidance to scale more on the 'direction' of the positive prompt rather than the difference between the negative prompt.",
             inputs=[
@@ -78,8 +78,10 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora
     else:
         out = vae_model.decode_(t_chunk)

-    if isinstance(out, (tuple, list)): out = out[0]
-    if out.ndim == 4: out = out.unsqueeze(2)
+    if isinstance(out, (tuple, list)):
+        out = out[0]
+    if out.ndim == 4:
+        out = out.unsqueeze(2)

     if pad_amount > 0:
         if encode:
@@ -136,13 +138,17 @@ def tiled_vae(x, vae_model, tile_size=(512, 512), tile_overlap=(64, 64), tempora

             if cur_ov_h > 0:
                 r = get_ramp(cur_ov_h)
-                if y_idx > 0: w_h[:cur_ov_h] = r
-                if y_end < h: w_h[-cur_ov_h:] = 1.0 - r
+                if y_idx > 0:
+                    w_h[:cur_ov_h] = r
+                if y_end < h:
+                    w_h[-cur_ov_h:] = 1.0 - r

             if cur_ov_w > 0:
                 r = get_ramp(cur_ov_w)
-                if x_idx > 0: w_w[:cur_ov_w] = r
-                if x_end < w: w_w[-cur_ov_w:] = 1.0 - r
+                if x_idx > 0:
+                    w_w[:cur_ov_w] = r
+                if x_end < w:
+                    w_w[-cur_ov_w:] = 1.0 - r

             final_weight = w_h.view(1,1,1,-1,1) * w_w.view(1,1,1,1,-1)
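Note: the weights assembled above implement linear cross-fades over tile overlaps, so overlapping tile contributions sum to one. A minimal 1-D sketch of the idea; get_ramp here is an assumed stand-in for the helper in the file being patched.

```python
# Minimal 1-D sketch of overlap cross-fading; get_ramp is an assumed stand-in.
import torch

def get_ramp(n):
    return torch.linspace(0.0, 1.0, n)

tile, overlap = 8, 4
w_prev = torch.ones(tile)
w_next = torch.ones(tile)
w_prev[-overlap:] = 1.0 - get_ramp(overlap)   # fade out at the trailing edge
w_next[:overlap] = get_ramp(overlap)          # fade in at the leading edge
print(w_prev[-overlap:] + w_next[:overlap])   # tensor([1., 1., 1., 1.])
```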
@@ -335,7 +341,8 @@ class SeedVR2InputProcessing(io.ComfyNode):

         comfy.model_management.load_models_gpu([vae.patcher])
         vae_model = vae.first_stage_model
-        scale = 0.9152; shift = 0
+        scale = 0.9152
+        shift = 0
         if images.dim() != 5: # add the t dim
             images = images.unsqueeze(0)
         images = images.permute(0, 1, 4, 2, 3)
@@ -817,7 +817,7 @@ def get_sample_indices(original_fps,
     if required_duration > total_frames / original_fps:
         raise ValueError("required_duration must be less than video length")

-    if not fixed_start is None and fixed_start >= 0:
+    if fixed_start is not None and fixed_start >= 0:
         start_frame = fixed_start
     else:
         max_start = total_frames - required_origin_frames
@@ -601,6 +601,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,

         if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
             tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
+            logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
             logging.error("Got an OOM, unloading all loaded models.")
             comfy.model_management.unload_all_models()
nodes.py (13 changed lines)
@@ -1663,8 +1663,6 @@ class LoadImage:
         output_masks = []
         w, h = None, None

-        excluded_formats = ['MPO']
-
         for i in ImageSequence.Iterator(img):
             i = node_helpers.pillow(ImageOps.exif_transpose, i)

@@ -1692,7 +1690,10 @@ class LoadImage:
             output_images.append(image)
             output_masks.append(mask.unsqueeze(0))

-        if len(output_images) > 1 and img.format not in excluded_formats:
+            if img.format == "MPO":
+                break # ignore all frames except the first one for MPO format
+
+        if len(output_images) > 1:
             output_image = torch.cat(output_images, dim=0)
             output_mask = torch.cat(output_masks, dim=0)
         else:
@@ -2242,8 +2243,10 @@ async def init_external_custom_nodes():

         for possible_module in possible_modules:
             module_path = os.path.join(custom_node_path, possible_module)
-            if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py": continue
-            if module_path.endswith(".disabled"): continue
+            if os.path.isfile(module_path) and os.path.splitext(module_path)[1] != ".py":
+                continue
+            if module_path.endswith(".disabled"):
+                continue
             if args.disable_all_custom_nodes and possible_module not in args.whitelist_custom_nodes:
                 logging.info(f"Skipping {possible_module} due to disable_all_custom_nodes and whitelist_custom_nodes")
                 continue
@@ -15,12 +15,16 @@ lint.select = [
    "N805", # invalid-first-argument-name-for-method
    "S307", # suspicious-eval-usage
    "S102", # exec
    "E",
    "T", # print-usage
    "W",
    # The "F" series in Ruff stands for "Pyflakes" rules, which catch various Python syntax errors and undefined names.
    # See all rules here: https://docs.astral.sh/ruff/rules/#pyflakes-f
    "F",
]

lint.ignore = ["E501", "E722", "E731", "E712", "E402", "E741"]

exclude = ["*.ipynb", "**/generated/*.pyi"]

[tool.pylint]
@@ -1,5 +1,5 @@
 comfyui-frontend-package==1.35.9
-comfyui-workflow-templates==0.7.64
+comfyui-workflow-templates==0.7.65
 comfyui-embedded-docs==0.3.1
 torch
 torchsde
@@ -324,7 +324,7 @@ class PromptServer():
         @routes.get("/models/{folder}")
         async def get_models(request):
             folder = request.match_info.get("folder", None)
-            if not folder in folder_paths.folder_names_and_paths:
+            if folder not in folder_paths.folder_names_and_paths:
                 return web.Response(status=404)
             files = folder_paths.get_filename_list(folder)
             return web.json_response(files)

@@ -579,7 +579,7 @@ class PromptServer():
             folder_name = request.match_info.get("folder_name", None)
             if folder_name is None:
                 return web.Response(status=404)
-            if not "filename" in request.rel_url.query:
+            if "filename" not in request.rel_url.query:
                 return web.Response(status=404)

             filename = request.rel_url.query["filename"]

@@ -593,7 +593,7 @@ class PromptServer():
             if out is None:
                 return web.Response(status=404)
             dt = json.loads(out)
-            if not "__metadata__" in dt:
+            if "__metadata__" not in dt:
                 return web.Response(status=404)
             return web.json_response(dt["__metadata__"])