Merge branch 'master' into feat/api-nodes/TopazVideo-Astra2
Commit 552f3a65a4
@@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
 from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
 from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
 import comfy.ldm.common_dit
+import comfy.model_prefetch

 class CompressedTimestep:
     """Store video timestep embeddings in compressed form using per-frame indexing."""

@@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
         """Process transformer blocks for LTXAV."""
         patches_replace = transformer_options.get("patches_replace", {})
         blocks_replace = patches_replace.get("dit", {})
+        prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)

         # Process transformer blocks
         for i, block in enumerate(self.transformer_blocks):
+            comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
             if ("double_block", i) in blocks_replace:

                 def block_wrap(args):

@@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
                 a_prompt_timestep=a_prompt_timestep,
             )

+        comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
+
         return [vx, ax]

     def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
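For orientation, the two `prefetch_queue` calls above bracket the block loop: each pop consumes the previous block's staged state and stages the weights for the block about to run, and a final pop with `None` drains the last block. A minimal sketch of the same driving pattern, assuming stand-in `blocks` and `x`; only the `comfy.model_prefetch` calls are from this commit:

```python
import comfy.model_prefetch

def run_blocks(blocks, x, transformer_options):
    # Returns None when prefetching is disabled (CPU device, no copy streams, ...).
    queue = comfy.model_prefetch.make_prefetch_queue(list(blocks), x.device, transformer_options)
    for block in blocks:
        # Consume the previous block's staged state, stage this block's weights.
        comfy.model_prefetch.prefetch_queue_pop(queue, x.device, block)
        x = block(x)
    # Final pop with no module: clean up the last block's staged state.
    comfy.model_prefetch.prefetch_queue_pop(queue, x.device, None)
    return x
```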
@@ -17,6 +17,7 @@
 """

 from __future__ import annotations
+import comfy.memory_management
 import comfy.utils
 import comfy.model_management
 import comfy.model_base

@@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
         weight = old_weight

     return weight
+
+def prefetch_prepared_value(value, allocate_buffer, stream):
+    if isinstance(value, torch.Tensor):
+        dest = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
+        comfy.model_management.cast_to_gathered([value], dest, non_blocking=True, stream=stream)
+        return comfy.memory_management.interpret_gathered_like([value], dest)[0]
+    elif isinstance(value, weight_adapter.WeightAdapterBase):
+        return type(value)(value.loaded_keys, prefetch_prepared_value(value.weights, allocate_buffer, stream))
+    elif isinstance(value, tuple):
+        return tuple(prefetch_prepared_value(item, allocate_buffer, stream) for item in value)
+    elif isinstance(value, list):
+        return [prefetch_prepared_value(item, allocate_buffer, stream) for item in value]
+
+    return value
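`prefetch_prepared_value` is a structural map: tensors are staged into a caller-allocated GPU buffer via `cast_to_gathered`/`interpret_gathered_like`, `WeightAdapterBase` instances are rebuilt around their staged weights, tuples and lists recurse, and everything else passes through. A plain-Python analogue of just that dispatch shape, with an arbitrary `fn` standing in for the tensor-staging leaf:

```python
def map_nested(value, fn):
    # Containers recurse; leaves go through fn (the real code stages tensors there).
    if isinstance(value, tuple):
        return tuple(map_nested(v, fn) for v in value)
    if isinstance(value, list):
        return [map_nested(v, fn) for v in value]
    return fn(value)

assert map_nested((1, [2, 3]), lambda v: v * 10) == (10, [20, 30])
```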
@@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module):
         if "latent_shapes" in extra_conds:
             xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))

+        transformer_options = transformer_options.copy()
+        transformer_options["prefetch_dynamic_vbars"] = (
+            self.current_patcher is not None and self.current_patcher.is_dynamic()
+        )
+
         model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
         if len(model_output) > 1 and not torch.is_tensor(model_output):
             model_output, _ = utils.pack_latents(model_output)
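Note the shallow copy before setting `prefetch_dynamic_vbars`: `transformer_options` can be shared across calls, so the per-call flag is written to a copy rather than the shared dict. A small illustration of the hazard being avoided (toy dict, not ComfyUI state):

```python
shared = {"patches_replace": {}}
per_call = shared.copy()
per_call["prefetch_dynamic_vbars"] = True
assert "prefetch_dynamic_vbars" not in shared  # the shared options stay untouched
```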
@@ -31,6 +31,7 @@ from contextlib import nullcontext
 import comfy.memory_management
 import comfy.utils
 import comfy.quant_ops
+import comfy_aimdo.vram_buffer

 class VRAMState(Enum):
     DISABLED = 0 #No vram present: no need to move models to vram

@@ -1175,6 +1176,10 @@ stream_counters = {}

 STREAM_CAST_BUFFERS = {}
 LARGEST_CASTED_WEIGHT = (None, 0)
+STREAM_AIMDO_CAST_BUFFERS = {}
+LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
+
+DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3

 def get_cast_buffer(offload_stream, device, size, ref):
     global LARGEST_CASTED_WEIGHT

@@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):

     return cast_buffer

+def get_aimdo_cast_buffer(offload_stream, device):
+    cast_buffer = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
+    if cast_buffer is None:
+        cast_buffer = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
+        STREAM_AIMDO_CAST_BUFFERS[offload_stream] = cast_buffer
+
+    return cast_buffer
+
 def reset_cast_buffers():
     global LARGEST_CASTED_WEIGHT
+    global LARGEST_AIMDO_CASTED_WEIGHT
+
     LARGEST_CASTED_WEIGHT = (None, 0)
-    for offload_stream in STREAM_CAST_BUFFERS:
-        offload_stream.synchronize()
+    LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
+    for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
+        if offload_stream is not None:
+            offload_stream.synchronize()
     synchronize()

     STREAM_CAST_BUFFERS.clear()
+    STREAM_AIMDO_CAST_BUFFERS.clear()
     soft_empty_cache()

 def get_offload_stream(device):
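`get_aimdo_cast_buffer` lazily creates one `VRAMBuffer` per offload stream and reuses it on later calls, with `DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3` (a 16 GiB reservation). A minimal sketch of the same per-stream lazy-cache pattern, with `make_buffer` standing in for the `VRAMBuffer` constructor:

```python
BUFFERS = {}

def get_buffer(stream, make_buffer):
    # One buffer per stream, created on first use and reused afterwards.
    buf = BUFFERS.get(stream, None)
    if buf is None:
        buf = make_buffer()
        BUFFERS[stream] = buf
    return buf

assert get_buffer("s0", list) is get_buffer("s0", list)  # cached per stream
```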
@@ -121,9 +121,20 @@ class LowVramPatch:
         self.patches = patches
         self.convert_func = convert_func # TODO: remove
         self.set_func = set_func
+        self.prepared_patches = None
+
+    def prepare(self, allocate_buffer, stream):
+        self.prepared_patches = [
+            (patch[0], comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream), patch[2], patch[3], patch[4])
+            for patch in self.patches[self.key]
+        ]
+
+    def clear_prepared(self):
+        self.prepared_patches = None

     def __call__(self, weight):
-        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
+        patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
+        return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)

 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
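The new `prepared_patches` field gives `LowVramPatch` a two-phase protocol: `prepare()` stages the patch tensors ahead of time (called from `cast_modules_with_vbar` in this commit), `__call__` prefers the staged copies, and `clear_prepared()` drops them. A runnable toy of just that switch, with integers standing in for patch tensors:

```python
class ToyPatch:
    def __init__(self, patches, key):
        self.patches = patches
        self.key = key
        self.prepared_patches = None

    def prepare(self, stage):
        # Stage every patch value ahead of time (GPU copies in the real code).
        self.prepared_patches = [stage(p) for p in self.patches[self.key]]

    def clear_prepared(self):
        self.prepared_patches = None

    def __call__(self, weight):
        patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
        return weight + sum(patches)

p = ToyPatch({"w": [1, 2]}, "w")
assert p(10) == 13            # unprepared: reads the raw patches
p.prepare(lambda v: v * 10)
assert p(10) == 40            # prepared: reads the staged copies
p.clear_prepared()
assert p(10) == 13            # back to the raw patches
```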
comfy/model_prefetch.py (new file, 65 lines)

@@ -0,0 +1,65 @@
+import comfy_aimdo.model_vbar
+import comfy.model_management
+import comfy.ops
+
+PREFETCH_QUEUES = []
+
+def cleanup_prefetched_modules(comfy_modules):
+    for s in comfy_modules:
+        prefetch = getattr(s, "_prefetch", None)
+        if prefetch is None:
+            continue
+        for param_key in ("weight", "bias"):
+            lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+            if lowvram_fn is not None:
+                lowvram_fn.clear_prepared()
+        if prefetch["signature"] is not None:
+            comfy_aimdo.model_vbar.vbar_unpin(s._v)
+        delattr(s, "_prefetch")
+
+def cleanup_prefetch_queues():
+    global PREFETCH_QUEUES
+
+    for queue in PREFETCH_QUEUES:
+        for entry in queue:
+            if entry is None or not isinstance(entry, tuple):
+                continue
+            _, prefetch_state = entry
+            comfy_modules = prefetch_state[1]
+            if comfy_modules is not None:
+                cleanup_prefetched_modules(comfy_modules)
+    PREFETCH_QUEUES = []
+
+def prefetch_queue_pop(queue, device, module):
+    if queue is None:
+        return
+
+    consumed = queue.pop(0)
+    if consumed is not None:
+        offload_stream, prefetch_state = consumed
+        offload_stream.wait_stream(comfy.model_management.current_stream(device))
+        _, comfy_modules = prefetch_state
+        if comfy_modules is not None:
+            cleanup_prefetched_modules(comfy_modules)
+
+    prefetch = queue[0]
+    if prefetch is not None:
+        comfy_modules = []
+        for s in prefetch.modules():
+            if hasattr(s, "_v"):
+                comfy_modules.append(s)
+
+        offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
+        comfy.model_management.sync_stream(device, offload_stream)
+        queue[0] = (offload_stream, (prefetch, comfy_modules))
+
+def make_prefetch_queue(queue, device, transformer_options):
+    if (not transformer_options.get("prefetch_dynamic_vbars", False)
+            or comfy.model_management.NUM_STREAMS == 0
+            or comfy.model_management.is_device_cpu(device)
+            or not comfy.model_management.device_supports_non_blocking(device)):
+        return None
+
+    queue = [None] + queue + [None]
+    PREFETCH_QUEUES.append(queue)
+    return queue
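`make_prefetch_queue` pads the block list to `[None] + blocks + [None]` so that every `prefetch_queue_pop` can do the same two steps: consume (and clean up) whatever it pops from the front, then stage whatever now sits at the front. The leading `None` makes the first pop stage-only and the trailing `None` makes the last pop cleanup-only. A toy trace of that rhythm, with prints in place of the real stage/cleanup work:

```python
def toy_pop(queue):
    consumed = queue.pop(0)
    if consumed is not None:
        print("cleanup", consumed)   # wait on its stream, unpin, drop state
    if queue and queue[0] is not None:
        print("stage", queue[0])     # upload this block's weights

blocks = ["block0", "block1"]
q = [None] + blocks + [None]
for b in blocks:
    toy_pop(q)   # stage block0 / cleanup block0, stage block1
toy_pop(q)       # final pop: cleanup block1 only
```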
comfy/ops.py (181 lines changed)

@@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
             setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))


-def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
-    #vbar doesn't support CPU weights, but some custom nodes have weird paths
-    #that might switch the layer to the CPU and expect it to work. We have to take
-    #a clone conservatively as we are mmapped and some SFT files are packed misaligned
-    #If you are a custom node author reading this, please move your layer to the GPU
-    #or declare your ModelPatcher as CPU in the first place.
-    if comfy.model_management.is_device_cpu(device):
-        materialize_meta_param(s, ["weight", "bias"])
-        weight = s.weight.to(dtype=dtype, copy=True)
-        if isinstance(weight, QuantizedTensor):
-            weight = weight.dequantize()
-        bias = None
-        if s.bias is not None:
-            bias = s.bias.to(dtype=bias_dtype, copy=True)
-        return weight, bias, (None, None, None)
-
+# FIXME: add n=1 cache hit fast path
+def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
     offload_stream = None
-    xfer_dest = None
+    cast_buffer = None
+    cast_buffer_offset = 0
+
+    def ensure_offload_stream(module, required_size, check_largest):
+        nonlocal offload_stream
+        nonlocal cast_buffer
+
+        if offload_stream is None:
+            offload_stream = comfy.model_management.get_offload_stream(device)
+        if offload_stream is None or not check_largest or len(comfy_modules) != 1:
+            return
+
+        current_size = 0 if cast_buffer is None else cast_buffer.size()
+        if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
+            offload_stream = comfy.model_management.get_offload_stream(device)
+            cast_buffer = None
+        if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
+            comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
+
+    def get_cast_buffer(buffer_size):
+        nonlocal offload_stream
+        nonlocal cast_buffer
+        nonlocal cast_buffer_offset
+
+        if buffer_size == 0:
+            return None
+
+        if offload_stream is None:
+            return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
+
+        cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
+        buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
+        cast_buffer_offset += buffer_size
+        return buffer
+
+    for s in comfy_modules:
+        signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
+        resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
+        prefetch = {
+            "signature": signature,
+            "resident": resident,
+        }

-    signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
-    resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
-    if signature is not None:
-        if resident:
-            weight = s._v_weight
-            bias = s._v_bias
-        else:
-            xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
-
-    if not resident:
-        materialize_meta_param(s, ["weight", "bias"])
+        if resident:
+            s._prefetch = prefetch
+            continue

-    cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
-    cast_dest = None
+        materialize_meta_param(s, ["weight", "bias"])
+        xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
+        cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
+        cast_dest = None
+        needs_cast = False

-    xfer_source = [ s.weight, s.bias ]
+        xfer_source = [ s.weight, s.bias ]

@@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
             if data is None:
                 continue
             if data.dtype != geometry.dtype:
+                needs_cast = True
                 cast_dest = xfer_dest
-                if cast_dest is None:
-                    cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
                 xfer_dest = None
                 break

         dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
-        offload_stream = comfy.model_management.get_offload_stream(device)
-        if xfer_dest is None and offload_stream is not None:
-            xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
-        if xfer_dest is None:
-            offload_stream = comfy.model_management.get_offload_stream(device)
-            xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
+        ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
         if xfer_dest is None:
-            xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
-            offload_stream = None
+            xfer_dest = get_cast_buffer(dest_size)

         if signature is None and pin is None:
             comfy.pinned_memory.pin_memory(s)

@@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
         xfer_source = [ pin ]
         #send it over
         comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
-        comfy.model_management.sync_stream(device, offload_stream)

-        if cast_dest is not None:
+        for param_key in ("weight", "bias"):
+            lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+            if lowvram_fn is not None:
+                ensure_offload_stream(s, cast_buffer_offset, False)
+                lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
+
+        prefetch["xfer_dest"] = xfer_dest
+        prefetch["cast_dest"] = cast_dest
+        prefetch["cast_geometry"] = cast_geometry
+        prefetch["needs_cast"] = needs_cast
+        s._prefetch = prefetch
+
+    return offload_stream
+
+
+def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
+    prefetch = getattr(s, "_prefetch", None)
+
+    if prefetch["resident"]:
+        weight = s._v_weight
+        bias = s._v_bias
+    else:
+        xfer_dest = prefetch["xfer_dest"]
+        if prefetch["needs_cast"]:
+            cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
             for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
-                                           comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
+                                           comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
                 if post_cast is not None:
                     post_cast.copy_(pre_cast)
             xfer_dest = cast_dest

-        params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
+        params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
         weight = params[0]
         bias = params[1]
-        if signature is not None:
+        if prefetch["signature"] is not None:
            s._v_weight = weight
            s._v_bias = bias
-            s._v_signature=signature
+            s._v_signature = prefetch["signature"]

    def post_cast(s, param_key, x, dtype, resident, update_weight):
        lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
        fns = getattr(s, param_key + "_function", [])
+
+        if x is None:
+            return None

        orig = x

        def to_dequant(tensor, dtype):

@@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
            x = f(x)
        return x

-    update_weight = signature is not None
-
-    weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
-    if s.bias is not None:
-        bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
-
-    #FIXME: weird offload return protocol
-    return weight, bias, (offload_stream, device if signature is not None else None, None)
+    update_weight = prefetch["signature"] is not None
+    weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
+    if bias is not None:
+        bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
+
+    return weight, bias


 def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):

@@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     if device is None:
         device = input.device

+    def format_return(result, offloadable):
+        weight, bias, offload_stream = result
+        return (weight, bias, offload_stream) if offloadable else (weight, bias)
+
     non_blocking = comfy.model_management.device_supports_non_blocking(device)

     if hasattr(s, "_v"):
-        return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
+        #vbar doesn't support CPU weights, but some custom nodes have weird paths
+        #that might switch the layer to the CPU and expect it to work. We have to take
+        #a clone conservatively as we are mmapped and some SFT files are packed misaligned
+        #If you are a custom node author reading this, please move your layer to the GPU
+        #or declare your ModelPatcher as CPU in the first place.
+        if comfy.model_management.is_device_cpu(device):
+            materialize_meta_param(s, ["weight", "bias"])
+            weight = s.weight.to(dtype=dtype, copy=True)
+            if isinstance(weight, QuantizedTensor):
+                weight = weight.dequantize()
+            bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
+            return format_return((weight, bias, (None, None, None)), offloadable)
+
+        prefetched = hasattr(s, "_prefetch")
+        offload_stream = None
+        offload_device = None
+        if not prefetched:
+            offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
+            comfy.model_management.sync_stream(device, offload_stream)
+
+        weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
+
+        if not prefetched:
+            if getattr(s, "_prefetch")["signature"] is not None:
+                offload_device = device
+            for param_key in ("weight", "bias"):
+                lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
+                if lowvram_fn is not None:
+                    lowvram_fn.clear_prepared()
+            delattr(s, "_prefetch")
+        return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)

     if offloadable and (device != s.weight.device or
                         (s.bias is not None and device != s.bias.device)):

@@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     for f in s.weight_function:
         weight = f(weight)

-    if offloadable:
-        return weight, bias, (offload_stream, weight_a, bias_a)
-    else:
-        #Legacy function signature
-        return weight, bias
+    return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)


 def uncast_bias_weight(s, weight, bias, offload_stream):
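The net effect in comfy/ops.py is that weight casting is split into a staging half (`cast_modules_with_vbar`, which records its results in a `_prefetch` dict on each module) and a consuming half (`resolve_cast_module_with_vbar`), so `cast_bias_weight` can serve both the prefetched and the old on-demand path through one code shape. A control-flow sketch under that reading, with a plain dict standing in for the `_prefetch` attribute:

```python
def toy_cast(module, staged):
    prefetched = module in staged            # mirrors hasattr(s, "_prefetch")
    if not prefetched:
        staged[module] = f"state({module})"  # on-demand: stage this module now
    weight = staged[module]                  # resolve from the staged state
    if not prefetched:
        del staged[module]                   # on-demand state is single-use
    return weight

staged = {}
print(toy_cast("linear1", staged))           # on-demand: staged and dropped here
staged["linear2"] = "state(linear2)"         # as the prefetch queue would do
print(toy_cast("linear2", staged))           # prefetched: cleanup left to the queue
```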
@@ -15,6 +15,7 @@ import torch
 from comfy.cli_args import args
 import comfy.memory_management
 import comfy.model_management
+import comfy.model_prefetch
 import comfy_aimdo.model_vbar

 from latent_preview import set_preview_method

@@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
         if args.verbose == "DEBUG":
             comfy_aimdo.control.analyze()
         comfy.model_management.reset_cast_buffers()
+        comfy.model_prefetch.cleanup_prefetch_queues()
         comfy_aimdo.model_vbar.vbars_reset_watermark_limits()

         if has_pending_tasks:
nodes.py (29 lines changed)

@@ -1694,26 +1694,27 @@ class LoadImage:

     RETURN_TYPES = ("IMAGE", "MASK")
     FUNCTION = "load_image"

     def load_image(self, image):
         image_path = folder_paths.get_annotated_filepath(image)

+        dtype = comfy.model_management.intermediate_dtype()
+        device = comfy.model_management.intermediate_device()
+
         components = InputImpl.VideoFromFile(image_path).get_components()
         if components.images.shape[0] > 0:
-            return (components.images, 1.0 - components.alpha[..., -1] if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=torch.float32, device="cpu"))
+            return (components.images.to(device=device, dtype=dtype), (1.0 - components.alpha[..., -1]).to(device=device, dtype=dtype) if components.alpha is not None else torch.zeros((components.images.shape[0], 64, 64), dtype=dtype, device=device))

+        # This code is left here to handle animated webp which pyav does not support loading
         img = node_helpers.pillow(Image.open, image_path)

         output_images = []
         output_masks = []
         w, h = None, None

-        dtype = comfy.model_management.intermediate_dtype()
-
         for i in ImageSequence.Iterator(img):
             i = node_helpers.pillow(ImageOps.exif_transpose, i)

-            if i.mode == 'I':
-                i = i.point(lambda i: i * (1 / 255))
             image = i.convert("RGB")

             if len(output_images) == 0:

@@ -1728,25 +1729,15 @@ class LoadImage:
             if 'A' in i.getbands():
                 mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
                 mask = 1. - torch.from_numpy(mask)
-            elif i.mode == 'P' and 'transparency' in i.info:
-                mask = np.array(i.convert('RGBA').getchannel('A')).astype(np.float32) / 255.0
-                mask = 1. - torch.from_numpy(mask)
             else:
-                mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
+                mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
             output_images.append(image.to(dtype=dtype))
             output_masks.append(mask.unsqueeze(0).to(dtype=dtype))

-            if img.format == "MPO":
-                break # ignore all frames except the first one for MPO format
-
-        if len(output_images) > 1:
-            output_image = torch.cat(output_images, dim=0)
-            output_mask = torch.cat(output_masks, dim=0)
-        else:
-            output_image = output_images[0]
-            output_mask = output_masks[0]
-
-        return (output_image, output_mask)
+        output_image = torch.cat(output_images, dim=0)
+        output_mask = torch.cat(output_masks, dim=0)
+
+        return (output_image.to(device=device, dtype=dtype), output_mask.to(device=device, dtype=dtype))

     @classmethod
     def IS_CHANGED(s, image):
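Behaviorally, `load_image` now normalizes both return paths onto `comfy.model_management.intermediate_device()`/`intermediate_dtype()`, and its tail always concatenates the frame list instead of special-casing single images (the MPO early-break and the palette-transparency mask path are removed). Concatenating a one-element frame list is shape-preserving, so the single-image case is unchanged, as this check shows:

```python
import torch

frames = [torch.zeros((1, 64, 64, 3))]
out = torch.cat(frames, dim=0)
assert out.shape == frames[0].shape  # cat of a single frame is shape-preserving
```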