mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Merge branch 'master' of github.com:comfyanonymous/ComfyUI
This commit is contained in:
commit
3d98440fb7
@ -7,7 +7,7 @@ on:
|
||||
description: 'cuda version'
|
||||
required: true
|
||||
type: string
|
||||
default: "121"
|
||||
default: "124"
|
||||
|
||||
python_minor:
|
||||
description: 'python minor version'
|
||||
@ -19,7 +19,7 @@ on:
|
||||
description: 'python patch version'
|
||||
required: true
|
||||
type: string
|
||||
default: "2"
|
||||
default: "3"
|
||||
# push:
|
||||
# branches:
|
||||
# - master
|
||||
@ -49,7 +49,7 @@ jobs:
|
||||
echo 'import site' >> ./python3${{ inputs.python_minor }}._pth
|
||||
curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
|
||||
./python.exe get-pip.py
|
||||
python -m pip wheel torch torchvision mpmath==1.3.0 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
python -m pip wheel torch torchvision mpmath==1.3.0 numpy==1.26.4 --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu${{ inputs.cu }} -r ../ComfyUI/requirements.txt pygit2 -w ../temp_wheel_dir
|
||||
ls ../temp_wheel_dir
|
||||
./python.exe -s -m pip install --pre ../temp_wheel_dir/*
|
||||
sed -i '1i../ComfyUI' ./python3${{ inputs.python_minor }}._pth
|
||||
|
||||
@ -76,9 +76,6 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
help="Enable cudaMallocAsync (enabled by default for torch 2.0 and up).")
|
||||
cm_group.add_argument("--disable-cuda-malloc", action="store_true", help="Disable cudaMallocAsync.")
|
||||
|
||||
parser.add_argument("--dont-upcast-attention", action="store_true",
|
||||
help="Disable upcasting of attention. Can boost speed but increase the chances of black images.")
|
||||
|
||||
fp_group = parser.add_mutually_exclusive_group()
|
||||
fp_group.add_argument("--force-fp32", action="store_true",
|
||||
help="Force fp32 (If this makes your GPU work better please report it).")
|
||||
@ -125,6 +122,9 @@ def create_parser() -> argparse.ArgumentParser:
|
||||
|
||||
parser.add_argument("--disable-xformers", action="store_true", help="Disable xformers.")
|
||||
|
||||
upcast = parser.add_mutually_exclusive_group()
|
||||
upcast.add_argument("--force-upcast-attention", action="store_true", help="Force enable attention upcasting, please report if it fixes black images.")
|
||||
upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disable all upcasting of attention. Should be unnecessary except for debugging.")
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
vram_group.add_argument("--gpu-only", action="store_true",
|
||||
help="Store and run everything (text encoders/CLIP models, etc... on the GPU).")
|
||||
|
||||
@ -37,6 +37,7 @@ class Configuration(dict):
|
||||
cuda_malloc (bool): Enable cudaMallocAsync. Defaults to True in applicable setups.
|
||||
disable_cuda_malloc (bool): Disable cudaMallocAsync.
|
||||
dont_upcast_attention (bool): Disable upcasting of attention.
|
||||
force_upcast_attention (bool): Force upcasting of attention.
|
||||
force_fp32 (bool): Force using FP32 precision.
|
||||
force_fp16 (bool): Force using FP16 precision.
|
||||
bf16_unet (bool): Use BF16 precision for UNet.
|
||||
@ -106,6 +107,7 @@ class Configuration(dict):
|
||||
self.cuda_malloc: bool = True
|
||||
self.disable_cuda_malloc: bool = False
|
||||
self.dont_upcast_attention: bool = False
|
||||
self.force_upcast_attention: bool = False
|
||||
self.force_fp32: bool = False
|
||||
self.force_fp16: bool = False
|
||||
self.bf16_unet: bool = False
|
||||
|
||||
@ -23,6 +23,15 @@ class TransformersManagedModel(ModelManageable):
|
||||
if model.device != self.offload_device:
|
||||
model.to(device=self.offload_device)
|
||||
|
||||
@property
|
||||
def lowvram_patch_counter(self):
|
||||
return 0
|
||||
|
||||
@lowvram_patch_counter.setter
|
||||
def lowvram_patch_counter(self, value: int):
|
||||
warnings.warn("Not supported")
|
||||
pass
|
||||
|
||||
load_device: torch.device
|
||||
offload_device: torch.device
|
||||
model: PreTrainedModel
|
||||
@ -57,7 +66,7 @@ class TransformersManagedModel(ModelManageable):
|
||||
def model_dtype(self) -> torch.dtype:
|
||||
return self.model.dtype
|
||||
|
||||
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
|
||||
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights=False) -> torch.nn.Module:
|
||||
warnings.warn("Transformers models do not currently support adapters like LoRAs")
|
||||
return self.model.to(device=device_to)
|
||||
|
||||
|
||||
@ -18,13 +18,13 @@ from ...cli_args import args
|
||||
from ... import ops
|
||||
ops = ops.disable_weight_init
|
||||
|
||||
# CrossAttn precision handling
|
||||
if args.dont_upcast_attention:
|
||||
logging.info("disabling upcasting of attention")
|
||||
_ATTN_PRECISION = "fp16"
|
||||
else:
|
||||
_ATTN_PRECISION = "fp32"
|
||||
|
||||
def get_attn_precision(attn_precision):
|
||||
if args.dont_upcast_attention:
|
||||
return None
|
||||
if attn_precision is None and args.force_upcast_attention:
|
||||
return torch.float32
|
||||
return attn_precision
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
@ -84,7 +84,9 @@ class FeedForward(nn.Module):
|
||||
def Normalize(in_channels, dtype=None, device=None):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
|
||||
|
||||
def attention_basic(q, k, v, heads, mask=None):
|
||||
def attention_basic(q, k, v, heads, mask=None, attn_precision=None):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
scale = dim_head ** -0.5
|
||||
@ -100,7 +102,7 @@ def attention_basic(q, k, v, heads, mask=None):
|
||||
)
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
if attn_precision == torch.float32:
|
||||
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
@ -134,7 +136,9 @@ def attention_basic(q, k, v, heads, mask=None):
|
||||
return out
|
||||
|
||||
|
||||
def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
def attention_sub_quad(query, key, value, heads, mask=None, attn_precision=None):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
|
||||
b, _, dim_head = query.shape
|
||||
dim_head //= heads
|
||||
|
||||
@ -145,7 +149,7 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1)
|
||||
|
||||
dtype = query.dtype
|
||||
upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32
|
||||
upcast_attention = attn_precision == torch.float32 and query.dtype != torch.float32
|
||||
if upcast_attention:
|
||||
bytes_per_token = torch.finfo(torch.float32).bits//8
|
||||
else:
|
||||
@ -194,7 +198,9 @@ def attention_sub_quad(query, key, value, heads, mask=None):
|
||||
hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2)
|
||||
return hidden_states
|
||||
|
||||
def attention_split(q, k, v, heads, mask=None):
|
||||
def attention_split(q, k, v, heads, mask=None, attn_precision=None):
|
||||
attn_precision = get_attn_precision(attn_precision)
|
||||
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
scale = dim_head ** -0.5
|
||||
@ -213,10 +219,12 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
|
||||
mem_free_total = model_management.get_free_memory(q.device)
|
||||
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
if attn_precision == torch.float32:
|
||||
element_size = 4
|
||||
upcast = True
|
||||
else:
|
||||
element_size = q.element_size()
|
||||
upcast = False
|
||||
|
||||
gb = 1024 ** 3
|
||||
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * element_size
|
||||
@ -250,7 +258,7 @@ def attention_split(q, k, v, heads, mask=None):
|
||||
slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
|
||||
for i in range(0, q.shape[1], slice_size):
|
||||
end = i + slice_size
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
if upcast:
|
||||
with torch.autocast(enabled=False, device_type = 'cuda'):
|
||||
s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale
|
||||
else:
|
||||
@ -301,7 +309,7 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
def attention_xformers(q, k, v, heads, mask=None):
|
||||
def attention_xformers(q, k, v, heads, mask=None, attn_precision=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
if BROKEN_XFORMERS:
|
||||
@ -333,7 +341,7 @@ def attention_xformers(q, k, v, heads, mask=None):
|
||||
)
|
||||
return out
|
||||
|
||||
def attention_pytorch(q, k, v, heads, mask=None):
|
||||
def attention_pytorch(q, k, v, heads, mask=None, attn_precision=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
q, k, v = map(
|
||||
@ -383,10 +391,11 @@ def optimized_attention_for_device(device, mask=False, small_input=False):
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=ops):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
@ -408,15 +417,15 @@ class CrossAttention(nn.Module):
|
||||
v = self.to_v(context)
|
||||
|
||||
if mask is None:
|
||||
out = optimized_attention(q, k, v, self.heads)
|
||||
out = optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision)
|
||||
else:
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask)
|
||||
out = optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, ff_in=False, inner_dim=None,
|
||||
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, dtype=None, device=None, operations=ops):
|
||||
disable_self_attn=False, disable_temporal_crossattention=False, switch_temporal_ca_to_sa=False, attn_precision=None, dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
|
||||
self.ff_in = ff_in or inner_dim is not None
|
||||
@ -424,6 +433,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
inner_dim = dim
|
||||
|
||||
self.is_res = inner_dim == dim
|
||||
self.attn_precision = attn_precision
|
||||
|
||||
if self.ff_in:
|
||||
self.norm_in = operations.LayerNorm(dim, dtype=dtype, device=device)
|
||||
@ -431,7 +441,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
||||
context_dim=context_dim if self.disable_self_attn else None, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(inner_dim, dim_out=dim, dropout=dropout, glu=gated_ff, dtype=dtype, device=device, operations=operations)
|
||||
|
||||
if disable_temporal_crossattention:
|
||||
@ -445,7 +455,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
context_dim_attn2 = context_dim
|
||||
|
||||
self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout, attn_precision=self.attn_precision, dtype=dtype, device=device, operations=operations) # is self-attn if context is none
|
||||
self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
|
||||
self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
|
||||
@ -475,6 +485,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
|
||||
extra_options["n_heads"] = self.n_heads
|
||||
extra_options["dim_head"] = self.d_head
|
||||
extra_options["attn_precision"] = self.attn_precision
|
||||
|
||||
if self.ff_in:
|
||||
x_skip = x
|
||||
@ -585,7 +596,7 @@ class SpatialTransformer(nn.Module):
|
||||
def __init__(self, in_channels, n_heads, d_head,
|
||||
depth=1, dropout=0., context_dim=None,
|
||||
disable_self_attn=False, use_linear=False,
|
||||
use_checkpoint=True, dtype=None, device=None, operations=ops):
|
||||
use_checkpoint=True, attn_precision=None, dtype=None, device=None, operations=ops):
|
||||
super().__init__()
|
||||
if exists(context_dim) and not isinstance(context_dim, list):
|
||||
context_dim = [context_dim] * depth
|
||||
@ -603,7 +614,7 @@ class SpatialTransformer(nn.Module):
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim[d],
|
||||
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, dtype=dtype, device=device, operations=operations)
|
||||
disable_self_attn=disable_self_attn, checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=dtype, device=device, operations=operations)
|
||||
for d in range(depth)]
|
||||
)
|
||||
if not use_linear:
|
||||
@ -659,6 +670,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
disable_self_attn=False,
|
||||
disable_temporal_crossattention=False,
|
||||
max_time_embed_period: int = 10000,
|
||||
attn_precision=None,
|
||||
dtype=None, device=None, operations=ops
|
||||
):
|
||||
super().__init__(
|
||||
@ -671,6 +683,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
context_dim=context_dim,
|
||||
use_linear=use_linear,
|
||||
disable_self_attn=disable_self_attn,
|
||||
attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
self.time_depth = time_depth
|
||||
@ -700,6 +713,7 @@ class SpatialVideoTransformer(SpatialTransformer):
|
||||
inner_dim=time_mix_inner_dim,
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
attn_precision=attn_precision,
|
||||
dtype=dtype, device=device, operations=operations
|
||||
)
|
||||
for _ in range(self.depth)
|
||||
|
||||
@ -431,6 +431,7 @@ class UNetModel(nn.Module):
|
||||
video_kernel_size=None,
|
||||
disable_temporal_crossattention=False,
|
||||
max_ddpm_temb_period=10000,
|
||||
attn_precision=None,
|
||||
device=None,
|
||||
operations=ops,
|
||||
):
|
||||
@ -550,13 +551,14 @@ class UNetModel(nn.Module):
|
||||
disable_self_attn=disable_self_attn,
|
||||
disable_temporal_crossattention=disable_temporal_crossattention,
|
||||
max_time_embed_period=max_ddpm_temb_period,
|
||||
attn_precision=attn_precision,
|
||||
dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
else:
|
||||
return SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=depth, context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn, use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint, dtype=self.dtype, device=device, operations=operations
|
||||
use_checkpoint=use_checkpoint, attn_precision=attn_precision, dtype=self.dtype, device=device, operations=operations
|
||||
)
|
||||
|
||||
def get_resblock(
|
||||
|
||||
@ -119,8 +119,8 @@ def get_total_memory(dev=None, torch_total_too=False):
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
mem_total_torch = mem_reserved
|
||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
@ -308,7 +308,7 @@ class LoadedModel:
|
||||
else:
|
||||
return self.model_memory()
|
||||
|
||||
def model_load(self, lowvram_model_memory=0):
|
||||
def model_load(self, lowvram_model_memory=0, force_patch_weights=False):
|
||||
patch_model_to = self.device
|
||||
|
||||
self.model.model_patches_to(self.device)
|
||||
@ -318,7 +318,7 @@ class LoadedModel:
|
||||
|
||||
try:
|
||||
if lowvram_model_memory > 0 and load_weights:
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory)
|
||||
self.real_model = self.model.patch_model_lowvram(device_to=patch_model_to, lowvram_model_memory=lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
else:
|
||||
self.real_model = self.model.patch_model(device_to=patch_model_to, patch_weights=load_weights)
|
||||
except Exception as e:
|
||||
@ -332,6 +332,11 @@ class LoadedModel:
|
||||
self.weights_loaded = True
|
||||
return self.real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
if force_patch_weights and self.model.lowvram_patch_counter > 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def model_unload(self, unpatch_weights=True):
|
||||
self.model.unpatch_model(self.model.offload_device, unpatch_weights=unpatch_weights)
|
||||
self.model.model_patches_to(self.model.offload_device)
|
||||
@ -408,7 +413,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
soft_empty_cache()
|
||||
|
||||
|
||||
def load_models_gpu(models, memory_required=0):
|
||||
def load_models_gpu(models, memory_required=0, force_patch_weights=False):
|
||||
global vram_state
|
||||
|
||||
with model_management_lock:
|
||||
@ -420,12 +425,21 @@ def load_models_gpu(models, memory_required=0):
|
||||
models_already_loaded = []
|
||||
for x in models:
|
||||
loaded_model = LoadedModel(x)
|
||||
loaded = None
|
||||
|
||||
if loaded_model in current_loaded_models:
|
||||
index = current_loaded_models.index(loaded_model)
|
||||
current_loaded_models.insert(0, current_loaded_models.pop(index))
|
||||
models_already_loaded.append(loaded_model)
|
||||
else:
|
||||
try:
|
||||
loaded_model_index = current_loaded_models.index(loaded_model)
|
||||
except ValueError:
|
||||
loaded_model_index = None
|
||||
|
||||
if loaded_model_index is not None:
|
||||
loaded = current_loaded_models[loaded_model_index]
|
||||
if loaded.should_reload_model(force_patch_weights=force_patch_weights): # TODO: cleanup this model reload logic
|
||||
current_loaded_models.pop(loaded_model_index).model_unload(unpatch_weights=True)
|
||||
loaded = None
|
||||
else:
|
||||
models_already_loaded.append(loaded)
|
||||
if loaded is None:
|
||||
if hasattr(x, "model"):
|
||||
logging.info(f"Requested to load {x.model.__class__.__name__}")
|
||||
models_to_load.append(loaded_model)
|
||||
@ -473,7 +487,7 @@ def load_models_gpu(models, memory_required=0):
|
||||
if vram_set_state == VRAMState.NO_VRAM:
|
||||
lowvram_model_memory = 64 * 1024 * 1024
|
||||
|
||||
cur_loaded_model = loaded_model.model_load(lowvram_model_memory)
|
||||
loaded_model.model_load(lowvram_model_memory, force_patch_weights=force_patch_weights)
|
||||
current_loaded_models.insert(0, loaded_model)
|
||||
return
|
||||
|
||||
@ -738,10 +752,10 @@ def get_free_memory(dev=None, torch_free_too=False):
|
||||
elif is_intel_xpu():
|
||||
stats = torch.xpu.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
mem_allocated = stats['allocated_bytes.all.current']
|
||||
mem_reserved = stats['reserved_bytes.all.current']
|
||||
mem_free_torch = mem_reserved - mem_active
|
||||
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
|
||||
mem_free_xpu = torch.xpu.get_device_properties(dev).total_memory - mem_reserved
|
||||
mem_free_total = mem_free_xpu + mem_free_torch
|
||||
else:
|
||||
stats = torch.cuda.memory_stats(dev)
|
||||
mem_active = stats['active_bytes.all.current']
|
||||
|
||||
@ -38,7 +38,7 @@ class ModelManageable(Protocol):
|
||||
def model_dtype(self) -> torch.dtype:
|
||||
...
|
||||
|
||||
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int) -> torch.nn.Module:
|
||||
def patch_model_lowvram(self, device_to: torch.device, lowvram_model_memory: int, force_patch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
...
|
||||
|
||||
def patch_model(self, device_to: torch.device, patch_weights: bool) -> torch.nn.Module:
|
||||
@ -46,3 +46,7 @@ class ModelManageable(Protocol):
|
||||
|
||||
def unpatch_model(self, offload_device: torch.device, unpatch_weights: Optional[bool] = False) -> torch.nn.Module:
|
||||
...
|
||||
|
||||
@property
|
||||
def lowvram_patch_counter(self) -> int:
|
||||
...
|
||||
@ -19,7 +19,7 @@ def apply_weight_decompose(dora_scale, weight):
|
||||
.transpose(0, 1)
|
||||
)
|
||||
|
||||
return weight * (dora_scale / weight_norm)
|
||||
return weight * (dora_scale / weight_norm).type(weight.dtype)
|
||||
|
||||
|
||||
def set_model_options_patch_replace(model_options, patch, name, block_name, number, transformer_index=None):
|
||||
@ -65,6 +65,15 @@ class ModelPatcher(ModelManageable):
|
||||
self.weight_inplace_update = weight_inplace_update
|
||||
self.model_lowvram = False
|
||||
self.patches_uuid = uuid.uuid4()
|
||||
self._lowvram_patch_counter = 0
|
||||
|
||||
@property
|
||||
def lowvram_patch_counter(self):
|
||||
return self._lowvram_patch_counter
|
||||
|
||||
@lowvram_patch_counter.setter
|
||||
def lowvram_patch_counter(self, value: int):
|
||||
self._lowvram_patch_counter = value
|
||||
|
||||
def model_size(self):
|
||||
if self.size > 0:
|
||||
@ -278,7 +287,7 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
return self.model
|
||||
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0):
|
||||
def patch_model_lowvram(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False):
|
||||
self.patch_model(device_to, patch_weights=False)
|
||||
|
||||
logging.info("loading in lowvram mode {}".format(lowvram_model_memory / (1024 * 1024)))
|
||||
@ -292,6 +301,7 @@ class ModelPatcher(ModelManageable):
|
||||
return self.model_patcher.calculate_weight(self.model_patcher.patches[self.key], weight, self.key)
|
||||
|
||||
mem_counter = 0
|
||||
patch_counter = 0
|
||||
for n, m in self.model.named_modules():
|
||||
lowvram_weight = False
|
||||
if hasattr(m, "comfy_cast_weights"):
|
||||
@ -304,9 +314,17 @@ class ModelPatcher(ModelManageable):
|
||||
|
||||
if lowvram_weight:
|
||||
if weight_key in self.patches:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(weight_key)
|
||||
else:
|
||||
m.weight_function = LowVramPatch(weight_key, self)
|
||||
patch_counter += 1
|
||||
if bias_key in self.patches:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
if force_patch_weights:
|
||||
self.patch_weight_to_device(bias_key)
|
||||
else:
|
||||
m.bias_function = LowVramPatch(bias_key, self)
|
||||
patch_counter += 1
|
||||
|
||||
m.prev_comfy_cast_weights = m.comfy_cast_weights
|
||||
m.comfy_cast_weights = True
|
||||
@ -319,6 +337,7 @@ class ModelPatcher(ModelManageable):
|
||||
logging.debug("lowvram: loaded module regularly {}".format(m))
|
||||
|
||||
self.model_lowvram = True
|
||||
self.lowvram_patch_counter = patch_counter
|
||||
return self.model
|
||||
|
||||
def calculate_weight(self, patches, weight, key):
|
||||
@ -470,6 +489,7 @@ class ModelPatcher(ModelManageable):
|
||||
m.bias_function = None
|
||||
|
||||
self.model_lowvram = False
|
||||
self.lowvram_patch_counter = 0
|
||||
|
||||
keys = list(self.backup.keys())
|
||||
|
||||
|
||||
@ -1464,6 +1464,9 @@ class LoadImage:
|
||||
|
||||
output_images = []
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
excluded_formats = ['MPO']
|
||||
|
||||
# maintain the legacy path
|
||||
# this will ultimately return a tensor, so we'd rather have the tensors directly
|
||||
@ -1478,6 +1481,14 @@ class LoadImage:
|
||||
if i.mode == 'I':
|
||||
i = i.point(lambda i: i * (1 / 255))
|
||||
image = i.convert("RGB")
|
||||
|
||||
if len(output_images) == 0:
|
||||
w = image.size[0]
|
||||
h = image.size[1]
|
||||
|
||||
if image.size[0] != w or image.size[1] != h:
|
||||
continue
|
||||
|
||||
image = np.array(image).astype(np.float32) / 255.0
|
||||
image = torch.from_numpy(image)[None,]
|
||||
if 'A' in i.getbands():
|
||||
@ -1488,14 +1499,14 @@ class LoadImage:
|
||||
output_images.append(image)
|
||||
output_masks.append(mask.unsqueeze(0))
|
||||
|
||||
if len(output_images) > 1:
|
||||
if len(output_images) > 1 and img.format not in excluded_formats:
|
||||
output_image = torch.cat(output_images, dim=0)
|
||||
output_mask = torch.cat(output_masks, dim=0)
|
||||
else:
|
||||
output_image = output_images[0]
|
||||
output_mask = output_masks[0]
|
||||
|
||||
return output_image, output_mask
|
||||
return (output_image, output_mask)
|
||||
|
||||
@classmethod
|
||||
def IS_CHANGED(s, image):
|
||||
|
||||
@ -582,7 +582,7 @@ def save_checkpoint(output_path, model, clip=None, vae=None, clip_vision=None, m
|
||||
load_models.append(clip.load_model())
|
||||
clip_sd = clip.get_sd()
|
||||
|
||||
model_management.load_models_gpu(load_models)
|
||||
model_management.load_models_gpu(load_models, force_patch_weights=True)
|
||||
clip_vision_sd = clip_vision.get_sd() if clip_vision is not None else None
|
||||
sd = model.model.state_dict_for_saving(clip_sd, vae.get_sd(), clip_vision_sd)
|
||||
for k in extra_keys:
|
||||
|
||||
@ -65,6 +65,12 @@ class SD20(supported_models_base.BASE):
|
||||
"use_temporal_attention": False,
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_heads": -1,
|
||||
"num_head_channels": 64,
|
||||
"attn_precision": torch.float32,
|
||||
}
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
def model_type(self, state_dict, prefix=""):
|
||||
@ -276,6 +282,12 @@ class SVD_img2vid(supported_models_base.BASE):
|
||||
"use_temporal_resblock": True
|
||||
}
|
||||
|
||||
unet_extra_config = {
|
||||
"num_heads": -1,
|
||||
"num_head_channels": 64,
|
||||
"attn_precision": torch.float32,
|
||||
}
|
||||
|
||||
clip_vision_prefix = "conditioner.embedders.0.open_clip.model.visual."
|
||||
|
||||
latent_format = latent_formats.SD15
|
||||
|
||||
@ -262,6 +262,36 @@ export class ComfyApp {
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#addRestoreWorkflowView() {
|
||||
const serialize = LGraph.prototype.serialize;
|
||||
const self = this;
|
||||
LGraph.prototype.serialize = function() {
|
||||
const workflow = serialize.apply(this, arguments);
|
||||
|
||||
// Store the drag & scale info in the serialized workflow if the setting is enabled
|
||||
if (self.enableWorkflowViewRestore.value) {
|
||||
if (!workflow.extra) {
|
||||
workflow.extra = {};
|
||||
}
|
||||
workflow.extra.ds = {
|
||||
scale: self.canvas.ds.scale,
|
||||
offset: self.canvas.ds.offset,
|
||||
};
|
||||
} else if (workflow.extra?.ds) {
|
||||
// Clear any old view data
|
||||
delete workflow.extra.ds;
|
||||
}
|
||||
|
||||
return workflow;
|
||||
}
|
||||
this.enableWorkflowViewRestore = this.ui.settings.addSetting({
|
||||
id: "Comfy.EnableWorkflowViewRestore",
|
||||
name: "Save and restore canvas position and zoom level in workflows",
|
||||
type: "boolean",
|
||||
defaultValue: true
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds special context menu handling for nodes
|
||||
@ -1505,6 +1535,7 @@ export class ComfyApp {
|
||||
this.#addProcessKeyHandler();
|
||||
this.#addConfigureHandler();
|
||||
this.#addApiUpdateHandlers();
|
||||
this.#addRestoreWorkflowView();
|
||||
|
||||
this.graph = new LGraph();
|
||||
|
||||
@ -1805,6 +1836,10 @@ export class ComfyApp {
|
||||
|
||||
try {
|
||||
this.graph.configure(graphData);
|
||||
if (this.enableWorkflowViewRestore.value && graphData.extra?.ds) {
|
||||
this.canvas.ds.offset = graphData.extra.ds.offset;
|
||||
this.canvas.ds.scale = graphData.extra.ds.scale;
|
||||
}
|
||||
} catch (error) {
|
||||
let errorHint = [];
|
||||
// Try extracting filename to see if it was caused by an extension script
|
||||
@ -2122,6 +2157,14 @@ export class ComfyApp {
|
||||
api.dispatchEvent(new CustomEvent("promptQueued", { detail: { number, batchCount } }));
|
||||
}
|
||||
|
||||
showErrorOnFileLoad(file) {
|
||||
this.ui.dialog.show(
|
||||
$el("div", [
|
||||
$el("p", {textContent: `Unable to find workflow in ${file.name}`})
|
||||
]).outerHTML
|
||||
);
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads workflow data from the specified file
|
||||
* @param {File} file
|
||||
@ -2129,27 +2172,27 @@ export class ComfyApp {
|
||||
async handleFile(file) {
|
||||
if (file.type === "image/png") {
|
||||
const pngInfo = await getPngMetadata(file);
|
||||
if (pngInfo) {
|
||||
if (pngInfo.workflow) {
|
||||
await this.loadGraphData(JSON.parse(pngInfo.workflow));
|
||||
} else if (pngInfo.prompt) {
|
||||
this.loadApiJson(JSON.parse(pngInfo.prompt));
|
||||
} else if (pngInfo.parameters) {
|
||||
importA1111(this.graph, pngInfo.parameters);
|
||||
}
|
||||
if (pngInfo?.workflow) {
|
||||
await this.loadGraphData(JSON.parse(pngInfo.workflow));
|
||||
} else if (pngInfo?.prompt) {
|
||||
this.loadApiJson(JSON.parse(pngInfo.prompt));
|
||||
} else if (pngInfo?.parameters) {
|
||||
importA1111(this.graph, pngInfo.parameters);
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file);
|
||||
}
|
||||
} else if (file.type === "image/webp") {
|
||||
const pngInfo = await getWebpMetadata(file);
|
||||
if (pngInfo) {
|
||||
if (pngInfo.workflow) {
|
||||
this.loadGraphData(JSON.parse(pngInfo.workflow));
|
||||
} else if (pngInfo.Workflow) {
|
||||
this.loadGraphData(JSON.parse(pngInfo.Workflow)); // Support loading workflows from that webp custom node.
|
||||
} else if (pngInfo.prompt) {
|
||||
this.loadApiJson(JSON.parse(pngInfo.prompt));
|
||||
} else if (pngInfo.Prompt) {
|
||||
this.loadApiJson(JSON.parse(pngInfo.Prompt)); // Support loading prompts from that webp custom node.
|
||||
}
|
||||
// Support loading workflows from that webp custom node.
|
||||
const workflow = pngInfo?.workflow || pngInfo?.Workflow;
|
||||
const prompt = pngInfo?.prompt || pngInfo?.Prompt;
|
||||
|
||||
if (workflow) {
|
||||
this.loadGraphData(JSON.parse(workflow));
|
||||
} else if (prompt) {
|
||||
this.loadApiJson(JSON.parse(prompt));
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file);
|
||||
}
|
||||
} else if (file.type === "application/json" || file.name?.endsWith(".json")) {
|
||||
const reader = new FileReader();
|
||||
@ -2170,7 +2213,11 @@ export class ComfyApp {
|
||||
await this.loadGraphData(JSON.parse(info.workflow));
|
||||
} else if (info.prompt) {
|
||||
this.loadApiJson(JSON.parse(info.prompt));
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file);
|
||||
}
|
||||
} else {
|
||||
this.showErrorOnFileLoad(file);
|
||||
}
|
||||
}
|
||||
|
||||
@ -2278,6 +2325,12 @@ export class ComfyApp {
|
||||
await this.#invokeExtensionsAsync("refreshComboInNodes", defs);
|
||||
}
|
||||
|
||||
resetView() {
|
||||
app.canvas.ds.scale = 1;
|
||||
app.canvas.ds.offset = [0, 0]
|
||||
app.graph.setDirtyCanvas(true, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean current state
|
||||
*/
|
||||
|
||||
@ -597,16 +597,23 @@ export class ComfyUI {
|
||||
if (!confirmClear.value || confirm("Clear workflow?")) {
|
||||
app.clean();
|
||||
app.graph.clear();
|
||||
app.resetView();
|
||||
}
|
||||
}
|
||||
}),
|
||||
$el("button", {
|
||||
id: "comfy-load-default-button", textContent: "Load Default", onclick: async () => {
|
||||
if (!confirmClear.value || confirm("Load default workflow?")) {
|
||||
app.resetView();
|
||||
await app.loadGraphData()
|
||||
}
|
||||
}
|
||||
}),
|
||||
$el("button", {
|
||||
id: "comfy-reset-view-button", textContent: "Reset View", onclick: async () => {
|
||||
app.resetView();
|
||||
}
|
||||
}),
|
||||
]);
|
||||
|
||||
const devMode = this.settings.addSetting({
|
||||
|
||||
@ -174,9 +174,14 @@ def save_checkpoint(model, clip=None, vae=None, clip_vision=None, filename_prefi
|
||||
|
||||
enable_modelspec = True
|
||||
if isinstance(model.model, model_base.SDXL):
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
|
||||
if isinstance(model.model, model_base.SDXL_instructpix2pix):
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-edit"
|
||||
else:
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-base"
|
||||
elif isinstance(model.model, model_base.SDXLRefiner):
|
||||
metadata["modelspec.architecture"] = "stable-diffusion-xl-v1-refiner"
|
||||
elif isinstance(model.model, model_base.SVD_img2vid):
|
||||
metadata["modelspec.architecture"] = "stable-video-diffusion-img2vid-v1"
|
||||
else:
|
||||
enable_modelspec = False
|
||||
|
||||
@ -261,7 +266,7 @@ class CLIPSave:
|
||||
for x in extra_pnginfo:
|
||||
metadata[x] = json.dumps(extra_pnginfo[x])
|
||||
|
||||
model_management.load_models_gpu([clip.load_model()])
|
||||
model_management.load_models_gpu([clip.load_model()], force_patch_weights=True)
|
||||
clip_sd = clip.get_sd()
|
||||
|
||||
for prefix in ["clip_l.", "clip_g.", ""]:
|
||||
|
||||
@ -5,12 +5,12 @@ import math
|
||||
|
||||
from einops import rearrange, repeat
|
||||
import os
|
||||
from comfy.ldm.modules.attention import optimized_attention, _ATTN_PRECISION
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy import samplers
|
||||
|
||||
# from comfy/ldm/modules/attention.py
|
||||
# but modified to return attention scores as well as output
|
||||
def attention_basic_with_sim(q, k, v, heads, mask=None):
|
||||
def attention_basic_with_sim(q, k, v, heads, mask=None, attn_precision=None):
|
||||
b, _, dim_head = q.shape
|
||||
dim_head //= heads
|
||||
scale = dim_head ** -0.5
|
||||
@ -26,7 +26,7 @@ def attention_basic_with_sim(q, k, v, heads, mask=None):
|
||||
)
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION =="fp32":
|
||||
if attn_precision == torch.float32:
|
||||
sim = einsum('b i d, b j d -> b i j', q.float(), k.float()) * scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * scale
|
||||
@ -121,13 +121,13 @@ class SelfAttentionGuidance:
|
||||
if 1 in cond_or_uncond:
|
||||
uncond_index = cond_or_uncond.index(1)
|
||||
# do the entire attention operation, but save the attention scores to attn_scores
|
||||
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads)
|
||||
(out, sim) = attention_basic_with_sim(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
|
||||
# when using a higher batch size, I BELIEVE the result batch dimension is [uc1, ... ucn, c1, ... cn]
|
||||
n_slices = heads * b
|
||||
attn_scores = sim[n_slices * uncond_index:n_slices * (uncond_index+1)]
|
||||
return out
|
||||
else:
|
||||
return optimized_attention(q, k, v, heads=heads)
|
||||
return optimized_attention(q, k, v, heads=heads, attn_precision=extra_options["attn_precision"])
|
||||
|
||||
def post_cfg_function(args):
|
||||
nonlocal attn_scores
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
import sys
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
@ -225,6 +226,7 @@ def test_image_exif_merge():
|
||||
|
||||
|
||||
@freeze_time("2024-01-14 03:21:34", tz_offset=-4)
|
||||
@pytest.mark.skipif(sys.platform == 'win32')
|
||||
def test_image_exif_creation_date_and_batch_number():
|
||||
assert ImageExifCreationDateAndBatchNumber.INPUT_TYPES() is not None
|
||||
n = ImageExifCreationDateAndBatchNumber()
|
||||
@ -264,7 +266,7 @@ def test_file_request_parameter(use_temporary_input_directory):
|
||||
image.save(image_path)
|
||||
|
||||
n = ImageRequestParameter()
|
||||
loaded_image, = n.execute(uri=image_path)
|
||||
loaded_image, = n.execute(value=image_path)
|
||||
assert loaded_image.shape == (1, 1, 1, 3)
|
||||
from comfy.nodes.base_nodes import LoadImage
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user