Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-24 21:30:15 +08:00)

Commit 3c90d0ea73: Merge branch 'comfyanonymous:master' into master

@@ -105,6 +105,7 @@ cache_group = parser.add_mutually_exclusive_group()
 cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
 cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
 cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
+cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threshold the cache removes large items to free RAM. Default 4GB.")
 
 attn_group = parser.add_mutually_exclusive_group()
 attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
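A note on the new flag's argparse semantics: with nargs='?' and const=4.0, passing --cache-ram with no value selects the 4 GB default headroom, passing an explicit number overrides it, and omitting the flag leaves the value at 0 so the feature stays off. A minimal, standalone sketch of that behavior (the parser below is illustrative, not ComfyUI's actual argument parser module):

import argparse

# Reproduces just the --cache-ram flag from the hunk above.
parser = argparse.ArgumentParser()
cache_group = parser.add_mutually_exclusive_group()
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0,
                         help="RAM pressure caching headroom threshold in GB.")

print(parser.parse_args([]).cache_ram)                    # 0   -> RAM pressure cache disabled
print(parser.parse_args(["--cache-ram"]).cache_ram)       # 4.0 -> default headroom
print(parser.parse_args(["--cache-ram", "8"]).cache_ram)  # 8.0 -> explicit headroom
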
@@ -588,7 +588,7 @@ class WanModel(torch.nn.Module):
         x = self.unpatchify(x, grid_sizes)
         return x
 
-    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None):
+    def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
         h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
@@ -601,10 +601,22 @@ class WanModel(torch.nn.Module):
         if steps_w is None:
             steps_w = w_len
 
+        h_start = 0
+        w_start = 0
+        rope_options = transformer_options.get("rope_options", None)
+        if rope_options is not None:
+            t_len = (t_len - 1.0) * rope_options.get("scale_t", 1.0) + 1.0
+            h_len = (h_len - 1.0) * rope_options.get("scale_y", 1.0) + 1.0
+            w_len = (w_len - 1.0) * rope_options.get("scale_x", 1.0) + 1.0
+
+            t_start += rope_options.get("shift_t", 0.0)
+            h_start += rope_options.get("shift_y", 0.0)
+            w_start += rope_options.get("shift_x", 0.0)
+
         img_ids = torch.zeros((steps_t, steps_h, steps_w, 3), device=device, dtype=dtype)
         img_ids[:, :, :, 0] = img_ids[:, :, :, 0] + torch.linspace(t_start, t_start + (t_len - 1), steps=steps_t, device=device, dtype=dtype).reshape(-1, 1, 1)
-        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(0, h_len - 1, steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
-        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(0, w_len - 1, steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
+        img_ids[:, :, :, 1] = img_ids[:, :, :, 1] + torch.linspace(h_start, h_start + (h_len - 1), steps=steps_h, device=device, dtype=dtype).reshape(1, -1, 1)
+        img_ids[:, :, :, 2] = img_ids[:, :, :, 2] + torch.linspace(w_start, w_start + (w_len - 1), steps=steps_w, device=device, dtype=dtype).reshape(1, 1, -1)
         img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
 
         freqs = self.rope_embedder(img_ids).movedim(1, 2)
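For readers tracing the arithmetic above: a length len is mapped to (len - 1) * scale + 1, so the end point of the torch.linspace range stretches or compresses while the number of sampled positions (steps) stays the same, and the shift simply offsets where the grid starts. A small standalone sketch of that transform (the helper name is invented; only the formula comes from the diff):

import torch

def rope_positions(length, steps, scale=1.0, shift=0.0):
    # Mirrors the hunk above: rescale the end of the position range, then offset it.
    scaled_len = (length - 1.0) * scale + 1.0
    start = 0.0 + shift
    return torch.linspace(start, start + (scaled_len - 1), steps=steps)

print(rope_positions(4, 4))              # tensor([0., 1., 2., 3.])
print(rope_positions(4, 4, scale=2.0))   # tensor([0., 2., 4., 6.])
print(rope_positions(4, 4, shift=1.5))   # tensor([1.5000, 2.5000, 3.5000, 4.5000])
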
@@ -630,7 +642,7 @@ class WanModel(torch.nn.Module):
         if self.ref_conv is not None and "reference_latent" in kwargs:
             t_len += 1
 
-        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype)
+        freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
 
     def unpatchify(self, x, grid_sizes):
@@ -276,6 +276,9 @@ class ModelPatcher:
         self.size = comfy.model_management.module_size(self.model)
         return self.size
 
+    def get_ram_usage(self):
+        return self.model_size()
+
     def loaded_size(self):
         return self.model.model_loaded_weight_memory
 
@@ -451,6 +454,19 @@ class ModelPatcher:
     def set_model_post_input_patch(self, patch):
         self.set_model_patch(patch, "post_input")
 
+    def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
+        rope_options = self.model_options["transformer_options"].get("rope_options", {})
+        rope_options["scale_x"] = scale_x
+        rope_options["scale_y"] = scale_y
+        rope_options["scale_t"] = scale_t
+
+        rope_options["shift_x"] = shift_x
+        rope_options["shift_y"] = shift_y
+        rope_options["shift_t"] = shift_t
+
+        self.model_options["transformer_options"]["rope_options"] = rope_options
+
+
     def add_object_patch(self, name, obj):
         self.object_patches[name] = obj
 
comfy/sd.py (14 lines changed)
@@ -143,6 +143,9 @@ class CLIP:
         n.apply_hooks_to_conds = self.apply_hooks_to_conds
         return n
 
+    def get_ram_usage(self):
+        return self.patcher.get_ram_usage()
+
     def add_patches(self, patches, strength_patch=1.0, strength_model=1.0):
         return self.patcher.add_patches(patches, strength_patch, strength_model)
 
@@ -293,6 +296,7 @@ class VAE:
         self.working_dtypes = [torch.bfloat16, torch.float32]
         self.disable_offload = False
         self.not_video = False
+        self.size = None
 
         self.downscale_index_formula = None
         self.upscale_index_formula = None
@@ -595,6 +599,16 @@ class VAE:
 
         self.patcher = comfy.model_patcher.ModelPatcher(self.first_stage_model, load_device=self.device, offload_device=offload_device)
         logging.info("VAE load device: {}, offload device: {}, dtype: {}".format(self.device, offload_device, self.vae_dtype))
+        self.model_size()
+
+    def model_size(self):
+        if self.size is not None:
+            return self.size
+        self.size = comfy.model_management.module_size(self.first_stage_model)
+        return self.size
+
+    def get_ram_usage(self):
+        return self.model_size()
 
     def throw_exception_if_invalid(self):
         if self.first_stage_model is None:
@@ -225,7 +225,7 @@ class OpenAIDalle2(ComfyNodeABC):
            ),
            files=(
                {
-                   "image": img_binary,
+                   "image": ("image.png", img_binary, "image/png"),
                }
                if img_binary
                else None
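The one-line change above swaps a bare file object for the (filename, fileobj, content_type) triple that requests-style multipart file dicts accept, so the uploaded part carries an explicit name and MIME type instead of whatever the client library guesses. A hedged, standalone illustration (the endpoint in the comment is a placeholder):

import io

img_binary = io.BytesIO(b"...png bytes...")

# Bare file object: the multipart field has no reliable filename or content type.
files_old = {"image": img_binary}

# (filename, fileobj, content_type) triple: the part is named image.png and tagged image/png.
files_new = {"image": ("image.png", img_binary, "image/png")}

# With the requests library this would be posted as, e.g.:
# requests.post("https://example.invalid/images/edits", files=files_new)  # placeholder endpoint
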
@@ -1,4 +1,9 @@
+import bisect
+import gc
 import itertools
+import psutil
+import time
+import torch
 from typing import Sequence, Mapping, Dict
 from comfy_execution.graph import DynamicPrompt
 from abc import ABC, abstractmethod
@@ -188,6 +193,9 @@ class BasicCache:
         self._clean_cache()
         self._clean_subcaches()
 
+    def poll(self, **kwargs):
+        pass
+
     def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
@@ -276,6 +284,9 @@ class NullCache:
     def clean_unused(self):
         pass
 
+    def poll(self, **kwargs):
+        pass
+
     def get(self, node_id):
         return None
 
@@ -336,3 +347,75 @@ class LRUCache(BasicCache):
             self._mark_used(child_id)
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
+
+
+#Iterating the cache for usage analysis might be expensive, so if we trigger make sure
+#to take a chunk out to give breathing space on high-node / low-ram-per-node flows.
+
+RAM_CACHE_HYSTERESIS = 1.1
+
+#This is kinda in GB but not really. It needs to be non-zero for the below heuristic
+#and as long as Multi GB models dwarf this it will approximate OOM scoring OK
+
+RAM_CACHE_DEFAULT_RAM_USAGE = 0.1
+
+#Exponential bias towards evicting older workflows so garbage will be taken out
+#in constantly changing setups.
+
+RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
+
+class RAMPressureCache(LRUCache):
+
+    def __init__(self, key_class):
+        super().__init__(key_class, 0)
+        self.timestamps = {}
+
+    def clean_unused(self):
+        self._clean_subcaches()
+
+    def set(self, node_id, value):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        super().set(node_id, value)
+
+    def get(self, node_id):
+        self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
+        return super().get(node_id)
+
+    def poll(self, ram_headroom):
+        def _ram_gb():
+            return psutil.virtual_memory().available / (1024**3)
+
+        if _ram_gb() > ram_headroom:
+            return
+        gc.collect()
+        if _ram_gb() > ram_headroom:
+            return
+
+        clean_list = []
+
+        for key, (outputs, _), in self.cache.items():
+            oom_score = RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** (self.generation - self.used_generation[key])
+
+            ram_usage = RAM_CACHE_DEFAULT_RAM_USAGE
+            def scan_list_for_ram_usage(outputs):
+                nonlocal ram_usage
+                for output in outputs:
+                    if isinstance(output, list):
+                        scan_list_for_ram_usage(output)
+                    elif isinstance(output, torch.Tensor) and output.device.type == 'cpu':
+                        #score Tensors at a 50% discount for RAM usage as they are likely to
+                        #be high value intermediates
+                        ram_usage += (output.numel() * output.element_size()) * 0.5
+                    elif hasattr(output, "get_ram_usage"):
+                        ram_usage += output.get_ram_usage()
+            scan_list_for_ram_usage(outputs)
+
+            oom_score *= ram_usage
+            #In the case where we have no information on the node ram usage at all,
+            #break OOM score ties on the last touch timestamp (pure LRU)
+            bisect.insort(clean_list, (oom_score, self.timestamps[key], key))
+
+        while _ram_gb() < ram_headroom * RAM_CACHE_HYSTERESIS and clean_list:
+            _, _, key = clean_list.pop()
+            del self.cache[key]
+            gc.collect()
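To make the eviction order above concrete: every cached entry gets an OOM score of RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER raised to the number of generations since the entry was last used, multiplied by its estimated RAM usage; entries are kept sorted by that score and popped from the large end until available RAM climbs back over ram_headroom * RAM_CACHE_HYSTERESIS. A self-contained sketch of just the scoring and ordering, with made-up entries:

import bisect

RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3

# (node key, generations since last use, estimated RAM usage in GB) - illustrative values
entries = [("small_recent", 0, 0.2), ("big_recent", 0, 6.0), ("big_old", 5, 6.0)]

clean_list = []
for key, age, ram_gb in entries:
    oom_score = (RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER ** age) * ram_gb
    bisect.insort(clean_list, (oom_score, key))

# Highest score goes first: big_old (~22.3), then big_recent (6.0), then small_recent (0.2).
while clean_list:
    score, key = clean_list.pop()
    print(f"evict {key} (score {score:.1f})")
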
@@ -209,10 +209,15 @@ class ExecutionList(TopologicalSort):
             self.execution_cache_listeners[from_node_id] = set()
         self.execution_cache_listeners[from_node_id].add(to_node_id)
 
-    def get_output_cache(self, from_node_id, to_node_id):
+    def get_cache(self, from_node_id, to_node_id):
         if not to_node_id in self.execution_cache:
             return None
-        return self.execution_cache[to_node_id].get(from_node_id)
+        value = self.execution_cache[to_node_id].get(from_node_id)
+        if value is None:
+            return None
+        #Write back to the main cache on touch.
+        self.output_cache.set(from_node_id, value)
+        return value
 
     def cache_update(self, node_id, value):
         if node_id in self.execution_cache_listeners:
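The write-back above addresses a subtlety of the split between the per-execution cache and the main output cache: when a downstream node pulls an input from the per-execution copy, the value is re-set into the main cache so that an LRU or RAM-pressure cache counts it as recently used rather than evicting it mid-workflow. A toy sketch of the pattern with stand-ins for the real cache classes (the names here are hypothetical):

import collections

class ToyMainCache:
    """Stand-in for the main output cache: every set() refreshes recency."""
    def __init__(self):
        self.store = collections.OrderedDict()

    def set(self, key, value):
        self.store[key] = value
        self.store.move_to_end(key)     # setting (or re-setting) marks the key most recent

def get_cache(execution_cache, main_cache, key):
    value = execution_cache.get(key)
    if value is None:
        return None
    main_cache.set(key, value)          # write back to the main cache on touch
    return value

main = ToyMainCache()
main.set("A", "node A outputs")
print(get_cache({"A": "node A outputs"}, main, "A"))  # touching A bumps it to most recent
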
comfy_extras/nodes_rope.py (new file, 47 lines)
@@ -0,0 +1,47 @@
+from comfy_api.latest import ComfyExtension, io
+from typing_extensions import override
+
+
+class ScaleROPE(io.ComfyNode):
+    @classmethod
+    def define_schema(cls):
+        return io.Schema(
+            node_id="ScaleROPE",
+            category="advanced/model_patches",
+            description="Scale and shift the ROPE of the model.",
+            is_experimental=True,
+            inputs=[
+                io.Model.Input("model"),
+                io.Float.Input("scale_x", default=1.0, min=0.0, max=100.0, step=0.1),
+                io.Float.Input("shift_x", default=0.0, min=-256.0, max=256.0, step=0.1),
+
+                io.Float.Input("scale_y", default=1.0, min=0.0, max=100.0, step=0.1),
+                io.Float.Input("shift_y", default=0.0, min=-256.0, max=256.0, step=0.1),
+
+                io.Float.Input("scale_t", default=1.0, min=0.0, max=100.0, step=0.1),
+                io.Float.Input("shift_t", default=0.0, min=-256.0, max=256.0, step=0.1),
+
+
+            ],
+            outputs=[
+                io.Model.Output(),
+            ],
+        )
+
+    @classmethod
+    def execute(cls, model, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t) -> io.NodeOutput:
+        m = model.clone()
+        m.set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t)
+        return io.NodeOutput(m)
+
+
+class RopeExtension(ComfyExtension):
+    @override
+    async def get_node_list(self) -> list[type[io.ComfyNode]]:
+        return [
+            ScaleROPE
+        ]
+
+
+async def comfy_entrypoint() -> RopeExtension:
+    return RopeExtension()
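Tying the new node back to the earlier hunks: ScaleROPE.execute clones the incoming ModelPatcher and forwards its six float inputs to set_model_rope_options, which files them under transformer_options["rope_options"], where the Wan model's rope_encode reads them. A standalone sketch of that bookkeeping, with a plain dict standing in for the patcher's model_options (the free function here is illustrative, not the real method):

model_options = {"transformer_options": {}}

def set_model_rope_options(scale_x, shift_x, scale_y, shift_y, scale_t, shift_t):
    # Same key layout as the ModelPatcher hunk earlier in this diff.
    rope_options = model_options["transformer_options"].get("rope_options", {})
    rope_options.update({"scale_x": scale_x, "scale_y": scale_y, "scale_t": scale_t,
                         "shift_x": shift_x, "shift_y": shift_y, "shift_t": shift_t})
    model_options["transformer_options"]["rope_options"] = rope_options

# ScaleROPE.execute passes its inputs positionally in this order:
set_model_rope_options(1.0, 0.0, 2.0, 0.0, 1.0, 8.0)
print(model_options["transformer_options"]["rope_options"])
# {'scale_x': 1.0, 'scale_y': 2.0, 'scale_t': 1.0, 'shift_x': 0.0, 'shift_y': 0.0, 'shift_t': 8.0}
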
execution.py (81 lines changed)
@@ -21,6 +21,7 @@ from comfy_execution.caching import (
     NullCache,
     HierarchicalCache,
     LRUCache,
+    RAMPressureCache,
 )
 from comfy_execution.graph import (
     DynamicPrompt,
@@ -88,49 +89,56 @@ class IsChangedCache:
         return self.is_changed[node_id]
 
 
+class CacheEntry(NamedTuple):
+    ui: dict
+    outputs: list
+
+
 class CacheType(Enum):
     CLASSIC = 0
     LRU = 1
     NONE = 2
+    RAM_PRESSURE = 3
 
 
 class CacheSet:
-    def __init__(self, cache_type=None, cache_size=None):
+    def __init__(self, cache_type=None, cache_args={}):
         if cache_type == CacheType.NONE:
             self.init_null_cache()
             logging.info("Disabling intermediate node cache.")
+        elif cache_type == CacheType.RAM_PRESSURE:
+            cache_ram = cache_args.get("ram", 16.0)
+            self.init_ram_cache(cache_ram)
+            logging.info("Using RAM pressure cache.")
         elif cache_type == CacheType.LRU:
-            if cache_size is None:
-                cache_size = 0
+            cache_size = cache_args.get("lru", 0)
             self.init_lru_cache(cache_size)
             logging.info("Using LRU cache")
         else:
            self.init_classic_cache()
 
-        self.all = [self.outputs, self.ui, self.objects]
+        self.all = [self.outputs, self.objects]
 
     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_lru_cache(self, cache_size):
         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
+        self.objects = HierarchicalCache(CacheKeySetID)
+
+    def init_ram_cache(self, min_headroom):
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_null_cache(self):
         self.outputs = NullCache()
-        #The UI cache is expected to be iterable at the end of each workflow
-        #so it must cache at least a full workflow. Use Heirachical
-        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = NullCache()
 
     def recursive_debug_dump(self):
         result = {
             "outputs": self.outputs.recursive_debug_dump(),
-            "ui": self.ui.recursive_debug_dump(),
         }
         return result
 
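The CacheEntry named tuple introduced above is the unit the outputs cache now stores per node: the UI payload and the node's output values, which the old code kept in two separate caches (the caches.ui cache is removed throughout this file). A minimal sketch of the container, with invented field contents:

from typing import NamedTuple

class CacheEntry(NamedTuple):
    ui: dict
    outputs: list

# One cached node result: the UI payload sent to the client plus the raw outputs
# that downstream nodes consume. The values below are made up.
entry = CacheEntry(ui={"meta": {"node_id": "5"}, "output": {"images": []}},
                   outputs=[["value for output slot 0"]])

print(entry.ui["meta"]["node_id"])   # "5"
print(entry.outputs[0])              # ['value for output slot 0']
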
@@ -157,14 +165,14 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
             if execution_list is None:
                 mark_missing()
                 continue # This might be a lazily-evaluated input
-            cached_output = execution_list.get_output_cache(input_unique_id, unique_id)
-            if cached_output is None:
+            cached = execution_list.get_cache(input_unique_id, unique_id)
+            if cached is None or cached.outputs is None:
                 mark_missing()
                 continue
-            if output_index >= len(cached_output):
+            if output_index >= len(cached.outputs):
                 mark_missing()
                 continue
-            obj = cached_output[output_index]
+            obj = cached.outputs[output_index]
             input_data_all[x] = obj
         elif input_category is not None:
             input_data_all[x] = [input_data]
@@ -393,7 +401,7 @@ def format_value(x):
     else:
         return str(x)
 
-async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes):
+async def execute(server, dynprompt, caches, current_item, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_outputs):
     unique_id = current_item
     real_node_id = dynprompt.get_real_node_id(unique_id)
     display_node_id = dynprompt.get_display_node_id(unique_id)
@@ -401,12 +409,15 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     inputs = dynprompt.get_node(unique_id)['inputs']
     class_type = dynprompt.get_node(unique_id)['class_type']
     class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
-    if caches.outputs.get(unique_id) is not None:
+    cached = caches.outputs.get(unique_id)
+    if cached is not None:
         if server.client_id is not None:
-            cached_output = caches.ui.get(unique_id) or {}
-            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_output.get("output",None), "prompt_id": prompt_id }, server.client_id)
+            cached_ui = cached.ui or {}
+            server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": cached_ui.get("output",None), "prompt_id": prompt_id }, server.client_id)
+        if cached.ui is not None:
+            ui_outputs[unique_id] = cached.ui
         get_progress_state().finish_progress(unique_id)
-        execution_list.cache_update(unique_id, caches.outputs.get(unique_id))
+        execution_list.cache_update(unique_id, cached)
         return (ExecutionResult.SUCCESS, None, None)
 
     input_data_all = None
@@ -436,8 +447,8 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
             for r in result:
                 if is_link(r):
                     source_node, source_output = r[0], r[1]
-                    node_output = execution_list.get_output_cache(source_node, unique_id)[source_output]
-                    for o in node_output:
+                    node_cached = execution_list.get_cache(source_node, unique_id)
+                    for o in node_cached.outputs[source_output]:
                         resolved_output.append(o)
 
                 else:
@@ -507,7 +518,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
            asyncio.create_task(await_completion())
            return (ExecutionResult.PENDING, None, None)
        if len(output_ui) > 0:
-           caches.ui.set(unique_id, {
+           ui_outputs[unique_id] = {
                "meta": {
                    "node_id": unique_id,
                    "display_node": display_node_id,
@@ -515,7 +526,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
                    "real_node_id": real_node_id,
                },
                "output": output_ui
-           })
+           }
            if server.client_id is not None:
                server.send_sync("executed", { "node": unique_id, "display_node": display_node_id, "output": output_ui, "prompt_id": prompt_id }, server.client_id)
        if has_subgraph:
@@ -554,8 +565,9 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
            pending_subgraph_results[unique_id] = cached_outputs
            return (ExecutionResult.PENDING, None, None)
 
-       caches.outputs.set(unique_id, output_data)
-       execution_list.cache_update(unique_id, output_data)
+       cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
+       execution_list.cache_update(unique_id, cache_entry)
+       caches.outputs.set(unique_id, cache_entry)
 
    except comfy.model_management.InterruptProcessingException as iex:
        logging.info("Processing interrupted")
@@ -600,14 +612,14 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
     return (ExecutionResult.SUCCESS, None, None)
 
 class PromptExecutor:
-    def __init__(self, server, cache_type=False, cache_size=None):
-        self.cache_size = cache_size
+    def __init__(self, server, cache_type=False, cache_args=None):
+        self.cache_args = cache_args
         self.cache_type = cache_type
         self.server = server
         self.reset()
 
     def reset(self):
-        self.caches = CacheSet(cache_type=self.cache_type, cache_size=self.cache_size)
+        self.caches = CacheSet(cache_type=self.cache_type, cache_args=self.cache_args)
         self.status_messages = []
         self.success = True
 
@@ -682,6 +694,7 @@ class PromptExecutor:
                             broadcast=False)
         pending_subgraph_results = {}
         pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
+        ui_node_outputs = {}
         executed = set()
         execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
         current_outputs = self.caches.outputs.all_node_ids()
@@ -695,7 +708,7 @@ class PromptExecutor:
                 break
 
             assert node_id is not None, "Node ID should not be None at this point"
-            result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes)
+            result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
             self.success = result != ExecutionResult.FAILURE
             if result == ExecutionResult.FAILURE:
                 self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
@@ -704,18 +717,16 @@ class PromptExecutor:
                 execution_list.unstage_node_execution()
             else: # result == ExecutionResult.SUCCESS:
                 execution_list.complete_node_execution()
+                self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
         else:
             # Only execute when the while-loop ends without break
             self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
 
         ui_outputs = {}
         meta_outputs = {}
-        all_node_ids = self.caches.ui.all_node_ids()
-        for node_id in all_node_ids:
-            ui_info = self.caches.ui.get(node_id)
-            if ui_info is not None:
-                ui_outputs[node_id] = ui_info["output"]
-                meta_outputs[node_id] = ui_info["meta"]
+        for node_id, ui_info in ui_node_outputs.items():
+            ui_outputs[node_id] = ui_info["output"]
+            meta_outputs[node_id] = ui_info["meta"]
         self.history_result = {
             "outputs": ui_outputs,
             "meta": meta_outputs,
main.py (4 lines changed)
@@ -172,10 +172,12 @@ def prompt_worker(q, server_instance):
     cache_type = execution.CacheType.CLASSIC
     if args.cache_lru > 0:
         cache_type = execution.CacheType.LRU
+    elif args.cache_ram > 0:
+        cache_type = execution.CacheType.RAM_PRESSURE
     elif args.cache_none:
         cache_type = execution.CacheType.NONE
 
-    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_size=args.cache_lru)
+    e = execution.PromptExecutor(server_instance, cache_type=cache_type, cache_args={ "lru" : args.cache_lru, "ram" : args.cache_ram } )
     last_gc_collect = 0
     need_gc = False
     gc_collect_interval = 10.0