mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-06 07:12:30 +08:00
Implement block prefetch + LoRA async load, and adopt in LTX (Speedup!) (CORE-111) (#13618)
* mm: Use Aimdo raw allocator for cast buffers. PyTorch manages allocation of growing buffers on streams poorly. PyTorch has no Windows support for the expandable segments allocator (which is the right tool for this job), while also segmenting the memory by stream such that it can be generally re-used. So kick the problem to Aimdo, which can just grow a virtual region that's freed per stream. * plan * ops: move CPU handler up to the caller * ops: split up prefetch from weight prep block prefetching API. Split up the casting and weight formatting/LoRA stuff in prep for arbitrary prefetch support. * ops: implement block prefetching API. Allow a model to construct a prefetch list and operate it for increased async offload. * ltxv2: Implement block prefetching * Implement LoRA async offload. Implement async offload of LoRAs.
This commit is contained in:
parent
3e3ed8cc2a
commit
783782d5d7
@ -16,6 +16,7 @@ from comfy.ldm.lightricks.model import (
|
|||||||
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
|
||||||
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
|
||||||
import comfy.ldm.common_dit
|
import comfy.ldm.common_dit
|
||||||
|
import comfy.model_prefetch
|
||||||
|
|
||||||
class CompressedTimestep:
|
class CompressedTimestep:
|
||||||
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
"""Store video timestep embeddings in compressed form using per-frame indexing."""
|
||||||
@ -907,9 +908,11 @@ class LTXAVModel(LTXVModel):
|
|||||||
"""Process transformer blocks for LTXAV."""
|
"""Process transformer blocks for LTXAV."""
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
blocks_replace = patches_replace.get("dit", {})
|
blocks_replace = patches_replace.get("dit", {})
|
||||||
|
prefetch_queue = comfy.model_prefetch.make_prefetch_queue(list(self.transformer_blocks), vx.device, transformer_options)
|
||||||
|
|
||||||
# Process transformer blocks
|
# Process transformer blocks
|
||||||
for i, block in enumerate(self.transformer_blocks):
|
for i, block in enumerate(self.transformer_blocks):
|
||||||
|
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, block)
|
||||||
if ("double_block", i) in blocks_replace:
|
if ("double_block", i) in blocks_replace:
|
||||||
|
|
||||||
def block_wrap(args):
|
def block_wrap(args):
|
||||||
@ -982,6 +985,8 @@ class LTXAVModel(LTXVModel):
|
|||||||
a_prompt_timestep=a_prompt_timestep,
|
a_prompt_timestep=a_prompt_timestep,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
comfy.model_prefetch.prefetch_queue_pop(prefetch_queue, vx.device, None)
|
||||||
|
|
||||||
return [vx, ax]
|
return [vx, ax]
|
||||||
|
|
||||||
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
def _process_output(self, x, embedded_timestep, keyframe_idxs, **kwargs):
|
||||||
|
|||||||
@ -17,6 +17,7 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
import comfy.memory_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.model_base
|
import comfy.model_base
|
||||||
@ -473,3 +474,17 @@ def calculate_weight(patches, weight, key, intermediate_dtype=torch.float32, ori
|
|||||||
weight = old_weight
|
weight = old_weight
|
||||||
|
|
||||||
return weight
|
return weight
|
||||||
|
|
||||||
|
def prefetch_prepared_value(value, allocate_buffer, stream):
    """Recursively copy the tensors inside a patch value into allocated buffers.

    Args:
        value: A tensor, a weight adapter, a tuple/list of such values, or
            anything else (which is passed through untouched).
        allocate_buffer: Callable taking a byte size and returning the
            destination buffer for one tensor.
        stream: Stream forwarded to ``cast_to_gathered`` for the copy.

    Returns:
        A value with the same structure as ``value`` in which every tensor
        has been replaced by the view over its destination buffer.
    """
    def stage(item):
        # Recurse with the same allocator and stream.
        return prefetch_prepared_value(item, allocate_buffer, stream)

    if isinstance(value, torch.Tensor):
        staging = allocate_buffer(comfy.memory_management.vram_aligned_size(value))
        # The copy is requested with non_blocking=True on the given stream.
        comfy.model_management.cast_to_gathered([value], staging, non_blocking=True, stream=stream)
        views = comfy.memory_management.interpret_gathered_like([value], staging)
        return views[0]

    if isinstance(value, weight_adapter.WeightAdapterBase):
        # Rebuild the same adapter type around the staged weights.
        adapter_cls = type(value)
        return adapter_cls(value.loaded_keys, stage(value.weights))

    if isinstance(value, tuple):
        return tuple(map(stage, value))

    if isinstance(value, list):
        return list(map(stage, value))

    return value
|
||||||
|
|||||||
@ -214,6 +214,11 @@ class BaseModel(torch.nn.Module):
|
|||||||
if "latent_shapes" in extra_conds:
|
if "latent_shapes" in extra_conds:
|
||||||
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
xc = utils.unpack_latents(xc, extra_conds.pop("latent_shapes"))
|
||||||
|
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
|
transformer_options["prefetch_dynamic_vbars"] = (
|
||||||
|
self.current_patcher is not None and self.current_patcher.is_dynamic()
|
||||||
|
)
|
||||||
|
|
||||||
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds)
|
||||||
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
if len(model_output) > 1 and not torch.is_tensor(model_output):
|
||||||
model_output, _ = utils.pack_latents(model_output)
|
model_output, _ = utils.pack_latents(model_output)
|
||||||
|
|||||||
@ -31,6 +31,7 @@ from contextlib import nullcontext
|
|||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.quant_ops
|
import comfy.quant_ops
|
||||||
|
import comfy_aimdo.vram_buffer
|
||||||
|
|
||||||
class VRAMState(Enum):
|
class VRAMState(Enum):
|
||||||
DISABLED = 0 #No vram present: no need to move models to vram
|
DISABLED = 0 #No vram present: no need to move models to vram
|
||||||
@ -1175,6 +1176,10 @@ stream_counters = {}
|
|||||||
|
|
||||||
STREAM_CAST_BUFFERS = {}
|
STREAM_CAST_BUFFERS = {}
|
||||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||||
|
STREAM_AIMDO_CAST_BUFFERS = {}
|
||||||
|
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||||
|
|
||||||
|
DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE = 16 * 1024 ** 3
|
||||||
|
|
||||||
def get_cast_buffer(offload_stream, device, size, ref):
|
def get_cast_buffer(offload_stream, device, size, ref):
|
||||||
global LARGEST_CASTED_WEIGHT
|
global LARGEST_CASTED_WEIGHT
|
||||||
@ -1208,13 +1213,26 @@ def get_cast_buffer(offload_stream, device, size, ref):
|
|||||||
|
|
||||||
return cast_buffer
|
return cast_buffer
|
||||||
|
|
||||||
|
def get_aimdo_cast_buffer(offload_stream, device):
    """Return the Aimdo cast buffer registered for ``offload_stream``.

    A buffer is created on first use for a given stream, reserving
    ``DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE`` on ``device.index``, and
    is remembered in ``STREAM_AIMDO_CAST_BUFFERS`` for later calls.
    """
    existing = STREAM_AIMDO_CAST_BUFFERS.get(offload_stream, None)
    if existing is not None:
        return existing

    created = comfy_aimdo.vram_buffer.VRAMBuffer(DEFAULT_AIMDO_CAST_BUFFER_RESERVATION_SIZE, device.index)
    STREAM_AIMDO_CAST_BUFFERS[offload_stream] = created
    return created
|
||||||
def reset_cast_buffers():
|
def reset_cast_buffers():
|
||||||
global LARGEST_CASTED_WEIGHT
|
global LARGEST_CASTED_WEIGHT
|
||||||
|
global LARGEST_AIMDO_CASTED_WEIGHT
|
||||||
|
|
||||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||||
for offload_stream in STREAM_CAST_BUFFERS:
|
LARGEST_AIMDO_CASTED_WEIGHT = (None, 0)
|
||||||
offload_stream.synchronize()
|
for offload_stream in set(STREAM_CAST_BUFFERS) | set(STREAM_AIMDO_CAST_BUFFERS):
|
||||||
|
if offload_stream is not None:
|
||||||
|
offload_stream.synchronize()
|
||||||
synchronize()
|
synchronize()
|
||||||
|
|
||||||
STREAM_CAST_BUFFERS.clear()
|
STREAM_CAST_BUFFERS.clear()
|
||||||
|
STREAM_AIMDO_CAST_BUFFERS.clear()
|
||||||
soft_empty_cache()
|
soft_empty_cache()
|
||||||
|
|
||||||
def get_offload_stream(device):
|
def get_offload_stream(device):
|
||||||
|
|||||||
@ -121,9 +121,20 @@ class LowVramPatch:
|
|||||||
self.patches = patches
|
self.patches = patches
|
||||||
self.convert_func = convert_func # TODO: remove
|
self.convert_func = convert_func # TODO: remove
|
||||||
self.set_func = set_func
|
self.set_func = set_func
|
||||||
|
self.prepared_patches = None
|
||||||
|
|
||||||
|
def prepare(self, allocate_buffer, stream):
    """Stage this key's patches and store them in ``self.prepared_patches``.

    Each patch in ``self.patches[self.key]`` is rebuilt as a 5-tuple whose
    second element has been passed through
    ``comfy.lora.prefetch_prepared_value``; the other elements are kept.

    Args:
        allocate_buffer: Callable forwarded to ``prefetch_prepared_value``.
        stream: Stream forwarded to ``prefetch_prepared_value``.
    """
    staged = []
    for patch in self.patches[self.key]:
        payload = comfy.lora.prefetch_prepared_value(patch[1], allocate_buffer, stream)
        staged.append((patch[0], payload, patch[2], patch[3], patch[4]))
    # Assign only once the whole list has been built.
    self.prepared_patches = staged
|
||||||
|
|
||||||
|
def clear_prepared(self):
    """Drop any staged patches so ``__call__`` uses ``self.patches`` again."""
    self.prepared_patches = None
|
||||||
|
|
||||||
def __call__(self, weight):
|
def __call__(self, weight):
|
||||||
return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
|
patches = self.prepared_patches if self.prepared_patches is not None else self.patches[self.key]
|
||||||
|
return comfy.lora.calculate_weight(patches, weight, self.key, intermediate_dtype=weight.dtype)
|
||||||
|
|
||||||
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2
|
||||||
|
|
||||||
|
|||||||
65
comfy/model_prefetch.py
Normal file
65
comfy/model_prefetch.py
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
import comfy_aimdo.model_vbar
|
||||||
|
import comfy.model_management
|
||||||
|
import comfy.ops
|
||||||
|
|
||||||
|
PREFETCH_QUEUES = []
|
||||||
|
|
||||||
|
def cleanup_prefetched_modules(comfy_modules):
    """Release the prefetch state attached to each module in ``comfy_modules``.

    Modules without a ``_prefetch`` attribute are left alone. For the others,
    the prepared patches of their weight/bias lowvram functions are cleared,
    the vbar is unpinned when the prefetch recorded a signature, and the
    ``_prefetch`` attribute is removed.
    """
    for module in comfy_modules:
        state = getattr(module, "_prefetch", None)
        if state is not None:
            for param_key in ("weight", "bias"):
                patch_fn = getattr(module, param_key + "_lowvram_function", None)
                if patch_fn is None:
                    continue
                patch_fn.clear_prepared()

            # A non-None signature means the prefetch faulted the vbar in.
            if state["signature"] is not None:
                comfy_aimdo.model_vbar.vbar_unpin(module._v)

            delattr(module, "_prefetch")
|
||||||
|
|
||||||
|
def cleanup_prefetch_queues():
    """Clean up every in-flight prefetch and forget all registered queues."""
    global PREFETCH_QUEUES

    for pending in PREFETCH_QUEUES:
        for slot in pending:
            # Only tuple slots carry prefetch state; None padding and
            # modules that were never prefetched are skipped.
            if isinstance(slot, tuple):
                _stream, state = slot
                tracked = state[1]
                if tracked is not None:
                    cleanup_prefetched_modules(tracked)

    PREFETCH_QUEUES = []
|
||||||
|
|
||||||
|
def prefetch_queue_pop(queue, device, module):
    """Retire the entry at the head of ``queue`` and prefetch the next one.

    The head entry is popped. If it is an in-flight prefetch, i.e. an
    ``(offload_stream, (block, comfy_modules))`` tuple, its offload stream is
    made to wait on the current stream of ``device`` and the prefetch state
    of its modules is cleaned up. The new head, if it is a module, is then
    prefetched via ``comfy.ops.cast_modules_with_vbar`` and replaced in the
    queue by its in-flight tuple.

    Args:
        queue: List built by ``make_prefetch_queue``, or None when
            prefetching is disabled. A None or exhausted queue is a no-op.
        device: Device whose current stream is synchronised against.
        module: Not used by this function.
    """
    if not queue:
        return

    consumed = queue.pop(0)
    if consumed is not None:
        offload_stream, prefetch_state = consumed
        # The stored stream can be None: cast_modules_with_vbar starts from
        # offload_stream = None and appears to leave it unset when every
        # module was already resident.
        if offload_stream is not None:
            offload_stream.wait_stream(comfy.model_management.current_stream(device))
        _, comfy_modules = prefetch_state
        if comfy_modules is not None:
            cleanup_prefetched_modules(comfy_modules)

    # Nothing left to prefetch once the trailing entry has been consumed.
    if not queue:
        return

    prefetch = queue[0]
    if prefetch is not None:
        # Only modules carrying a vbar ("_v") take part in prefetching.
        comfy_modules = [s for s in prefetch.modules() if hasattr(s, "_v")]

        offload_stream = comfy.ops.cast_modules_with_vbar(comfy_modules, None, device, None, True)
        comfy.model_management.sync_stream(device, offload_stream)
        queue[0] = (offload_stream, (prefetch, comfy_modules))
|
||||||
|
|
||||||
|
def make_prefetch_queue(queue, device, transformer_options):
    """Build and register a prefetch queue, or return None when disabled.

    Prefetching is skipped when ``transformer_options`` does not enable
    ``prefetch_dynamic_vbars``, when no streams are configured, when
    ``device`` is the CPU, or when it does not support non-blocking copies.

    Args:
        queue: List of modules to prefetch, in execution order.
        device: Target device for the prefetch.
        transformer_options: Options dict read for ``prefetch_dynamic_vbars``.

    Returns:
        The module list padded with a leading and a trailing None (also
        appended to ``PREFETCH_QUEUES``), or None when prefetching is off.
    """
    if not transformer_options.get("prefetch_dynamic_vbars", False):
        return None
    if comfy.model_management.NUM_STREAMS == 0:
        return None
    if comfy.model_management.is_device_cpu(device):
        return None
    if not comfy.model_management.device_supports_non_blocking(device):
        return None

    # The None at each end lets prefetch_queue_pop treat the first and last
    # pops uniformly.
    padded = [None] + queue
    padded = padded + [None]
    PREFETCH_QUEUES.append(padded)
    return padded
|
||||||
181
comfy/ops.py
181
comfy/ops.py
@ -86,38 +86,61 @@ def materialize_meta_param(s, param_keys):
|
|||||||
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
# FIXME: add n=1 cache hit fast path
|
||||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
|
||||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
|
||||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
|
||||||
#If you are a custom node author reading this, please move your layer to the GPU
|
|
||||||
#or declare your ModelPatcher as CPU in the first place.
|
|
||||||
if comfy.model_management.is_device_cpu(device):
|
|
||||||
materialize_meta_param(s, ["weight", "bias"])
|
|
||||||
weight = s.weight.to(dtype=dtype, copy=True)
|
|
||||||
if isinstance(weight, QuantizedTensor):
|
|
||||||
weight = weight.dequantize()
|
|
||||||
bias = None
|
|
||||||
if s.bias is not None:
|
|
||||||
bias = s.bias.to(dtype=bias_dtype, copy=True)
|
|
||||||
return weight, bias, (None, None, None)
|
|
||||||
|
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
xfer_dest = None
|
cast_buffer = None
|
||||||
|
cast_buffer_offset = 0
|
||||||
|
|
||||||
|
def ensure_offload_stream(module, required_size, check_largest):
|
||||||
|
nonlocal offload_stream
|
||||||
|
nonlocal cast_buffer
|
||||||
|
|
||||||
|
if offload_stream is None:
|
||||||
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
|
if offload_stream is None or not check_largest or len(comfy_modules) != 1:
|
||||||
|
return
|
||||||
|
|
||||||
|
current_size = 0 if cast_buffer is None else cast_buffer.size()
|
||||||
|
if current_size < required_size and module is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
|
||||||
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
|
cast_buffer = None
|
||||||
|
if required_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
|
||||||
|
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (module, required_size)
|
||||||
|
|
||||||
|
def get_cast_buffer(buffer_size):
|
||||||
|
nonlocal offload_stream
|
||||||
|
nonlocal cast_buffer
|
||||||
|
nonlocal cast_buffer_offset
|
||||||
|
|
||||||
|
if buffer_size == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if offload_stream is None:
|
||||||
|
return torch.empty((buffer_size,), dtype=torch.uint8, device=device)
|
||||||
|
|
||||||
|
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
||||||
|
buffer = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(buffer_size, cast_buffer_offset), device)
|
||||||
|
cast_buffer_offset += buffer_size
|
||||||
|
return buffer
|
||||||
|
|
||||||
|
for s in comfy_modules:
|
||||||
|
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||||
|
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||||
|
prefetch = {
|
||||||
|
"signature": signature,
|
||||||
|
"resident": resident,
|
||||||
|
}
|
||||||
|
|
||||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
|
||||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
|
||||||
if signature is not None:
|
|
||||||
if resident:
|
if resident:
|
||||||
weight = s._v_weight
|
s._prefetch = prefetch
|
||||||
bias = s._v_bias
|
continue
|
||||||
else:
|
|
||||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
|
||||||
|
|
||||||
if not resident:
|
|
||||||
materialize_meta_param(s, ["weight", "bias"])
|
materialize_meta_param(s, ["weight", "bias"])
|
||||||
|
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
|
||||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||||
cast_dest = None
|
cast_dest = None
|
||||||
|
needs_cast = False
|
||||||
|
|
||||||
xfer_source = [ s.weight, s.bias ]
|
xfer_source = [ s.weight, s.bias ]
|
||||||
|
|
||||||
@ -129,22 +152,15 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
if data is None:
|
if data is None:
|
||||||
continue
|
continue
|
||||||
if data.dtype != geometry.dtype:
|
if data.dtype != geometry.dtype:
|
||||||
|
needs_cast = True
|
||||||
cast_dest = xfer_dest
|
cast_dest = xfer_dest
|
||||||
if cast_dest is None:
|
|
||||||
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
|
||||||
xfer_dest = None
|
xfer_dest = None
|
||||||
break
|
break
|
||||||
|
|
||||||
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
ensure_offload_stream(s, dest_size if xfer_dest is None else 0, True)
|
||||||
if xfer_dest is None and offload_stream is not None:
|
|
||||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
|
||||||
if xfer_dest is None:
|
|
||||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
|
||||||
xfer_dest = comfy.model_management.get_cast_buffer(offload_stream, device, dest_size, s)
|
|
||||||
if xfer_dest is None:
|
if xfer_dest is None:
|
||||||
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
xfer_dest = get_cast_buffer(dest_size)
|
||||||
offload_stream = None
|
|
||||||
|
|
||||||
if signature is None and pin is None:
|
if signature is None and pin is None:
|
||||||
comfy.pinned_memory.pin_memory(s)
|
comfy.pinned_memory.pin_memory(s)
|
||||||
@ -157,27 +173,54 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
xfer_source = [ pin ]
|
xfer_source = [ pin ]
|
||||||
#send it over
|
#send it over
|
||||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||||
comfy.model_management.sync_stream(device, offload_stream)
|
|
||||||
|
|
||||||
if cast_dest is not None:
|
for param_key in ("weight", "bias"):
|
||||||
|
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||||
|
if lowvram_fn is not None:
|
||||||
|
ensure_offload_stream(s, cast_buffer_offset, False)
|
||||||
|
lowvram_fn.prepare(lambda size: get_cast_buffer(size), offload_stream)
|
||||||
|
|
||||||
|
prefetch["xfer_dest"] = xfer_dest
|
||||||
|
prefetch["cast_dest"] = cast_dest
|
||||||
|
prefetch["cast_geometry"] = cast_geometry
|
||||||
|
prefetch["needs_cast"] = needs_cast
|
||||||
|
s._prefetch = prefetch
|
||||||
|
|
||||||
|
return offload_stream
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant):
|
||||||
|
|
||||||
|
prefetch = getattr(s, "_prefetch", None)
|
||||||
|
|
||||||
|
if prefetch["resident"]:
|
||||||
|
weight = s._v_weight
|
||||||
|
bias = s._v_bias
|
||||||
|
else:
|
||||||
|
xfer_dest = prefetch["xfer_dest"]
|
||||||
|
if prefetch["needs_cast"]:
|
||||||
|
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
|
||||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
|
||||||
if post_cast is not None:
|
if post_cast is not None:
|
||||||
post_cast.copy_(pre_cast)
|
post_cast.copy_(pre_cast)
|
||||||
xfer_dest = cast_dest
|
xfer_dest = cast_dest
|
||||||
|
|
||||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
|
||||||
weight = params[0]
|
weight = params[0]
|
||||||
bias = params[1]
|
bias = params[1]
|
||||||
if signature is not None:
|
if prefetch["signature"] is not None:
|
||||||
s._v_weight = weight
|
s._v_weight = weight
|
||||||
s._v_bias = bias
|
s._v_bias = bias
|
||||||
s._v_signature=signature
|
s._v_signature = prefetch["signature"]
|
||||||
|
|
||||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||||
fns = getattr(s, param_key + "_function", [])
|
fns = getattr(s, param_key + "_function", [])
|
||||||
|
|
||||||
|
if x is None:
|
||||||
|
return None
|
||||||
|
|
||||||
orig = x
|
orig = x
|
||||||
|
|
||||||
def to_dequant(tensor, dtype):
|
def to_dequant(tensor, dtype):
|
||||||
@ -205,14 +248,12 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
x = f(x)
|
x = f(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
update_weight = signature is not None
|
update_weight = prefetch["signature"] is not None
|
||||||
|
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
|
||||||
|
if bias is not None:
|
||||||
|
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
|
||||||
|
|
||||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
return weight, bias
|
||||||
if s.bias is not None:
|
|
||||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
|
||||||
|
|
||||||
#FIXME: weird offload return protocol
|
|
||||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||||
@ -230,10 +271,46 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = input.device
|
device = input.device
|
||||||
|
|
||||||
|
def format_return(result, offloadable):
|
||||||
|
weight, bias, offload_stream = result
|
||||||
|
return (weight, bias, offload_stream) if offloadable else (weight, bias)
|
||||||
|
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
|
|
||||||
if hasattr(s, "_v"):
|
if hasattr(s, "_v"):
|
||||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
|
||||||
|
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||||
|
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||||
|
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||||
|
#If you are a custom node author reading this, please move your layer to the GPU
|
||||||
|
#or declare your ModelPatcher as CPU in the first place.
|
||||||
|
if comfy.model_management.is_device_cpu(device):
|
||||||
|
materialize_meta_param(s, ["weight", "bias"])
|
||||||
|
weight = s.weight.to(dtype=dtype, copy=True)
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight = weight.dequantize()
|
||||||
|
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||||
|
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||||
|
|
||||||
|
prefetched = hasattr(s, "_prefetch")
|
||||||
|
offload_stream = None
|
||||||
|
offload_device = None
|
||||||
|
if not prefetched:
|
||||||
|
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
|
||||||
|
comfy.model_management.sync_stream(device, offload_stream)
|
||||||
|
|
||||||
|
weight, bias = resolve_cast_module_with_vbar(s, dtype, device, bias_dtype, compute_dtype, want_requant)
|
||||||
|
|
||||||
|
if not prefetched:
|
||||||
|
if getattr(s, "_prefetch")["signature"] is not None:
|
||||||
|
offload_device = device
|
||||||
|
for param_key in ("weight", "bias"):
|
||||||
|
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||||
|
if lowvram_fn is not None:
|
||||||
|
lowvram_fn.clear_prepared()
|
||||||
|
delattr(s, "_prefetch")
|
||||||
|
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
|
||||||
|
|
||||||
|
|
||||||
if offloadable and (device != s.weight.device or
|
if offloadable and (device != s.weight.device or
|
||||||
(s.bias is not None and device != s.bias.device)):
|
(s.bias is not None and device != s.bias.device)):
|
||||||
@ -280,11 +357,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
for f in s.weight_function:
|
for f in s.weight_function:
|
||||||
weight = f(weight)
|
weight = f(weight)
|
||||||
|
|
||||||
if offloadable:
|
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
|
||||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
|
||||||
else:
|
|
||||||
#Legacy function signature
|
|
||||||
return weight, bias
|
|
||||||
|
|
||||||
|
|
||||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import torch
|
|||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy.model_prefetch
|
||||||
import comfy_aimdo.model_vbar
|
import comfy_aimdo.model_vbar
|
||||||
|
|
||||||
from latent_preview import set_preview_method
|
from latent_preview import set_preview_method
|
||||||
@ -537,6 +538,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
if args.verbose == "DEBUG":
|
if args.verbose == "DEBUG":
|
||||||
comfy_aimdo.control.analyze()
|
comfy_aimdo.control.analyze()
|
||||||
comfy.model_management.reset_cast_buffers()
|
comfy.model_management.reset_cast_buffers()
|
||||||
|
comfy.model_prefetch.cleanup_prefetch_queues()
|
||||||
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||||
|
|
||||||
if has_pending_tasks:
|
if has_pending_tasks:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user