ops: split up prefetch from weight prep block prefetching API

Split up the casting and weight formating/lora stuff in prep for
arbitrary prefetch support.
This commit is contained in:
Rattus 2026-04-27 22:52:07 +10:00
parent 132c9f3ac6
commit 0e93c88c67

View File

@ -86,27 +86,29 @@ def materialize_meta_param(s, param_keys):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#plan = []
#Some sort of loop here like what you did
#for module in comfy_modules: ...
# FIXME: add n=1 cache hit fast path
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
offload_stream = None
xfer_dest = None
cast_buffer = None
cast_buffer_offset = 0
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
prefetch = {
"signature": signature,
"resident": resident,
}
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
s._prefetch = prefetch
continue
if not resident:
materialize_meta_param(s, ["weight", "bias"])
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None
needs_cast = False
xfer_source = [ s.weight, s.bias ]
@ -118,25 +120,29 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None:
continue
if data.dtype != geometry.dtype:
needs_cast = True
cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None
break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None:
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
if offload_stream is None:
offload_stream = comfy.model_management.get_offload_stream(device)
if xfer_dest is None and offload_stream is not None and cast_buffer is None:
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
if len(comfy_modules) == 1:
if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size), device)
if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size)
if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
if cast_buffer is not None:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size, cast_buffer_offset), device)
cast_buffer_offset += dest_size
else:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s)
@ -149,29 +155,45 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ]
#send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
#attach prefetch info to the module inside the loop ..
prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
prefetch["cast_geometry"] = cast_geometry
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
#this sync is conceptually the last thing this function does - after the loop
comfy.model_management.sync_stream(device, offload_stream)
return offload_stream
#all compute stuff need to be deferred to the new second phase
if cast_dest is not None:
def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
del non_blocking
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
raise RuntimeError("phase_2 called without a VBAR prefetch state")
if prefetch["resident"]:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = prefetch["xfer_dest"]
if prefetch["needs_cast"]:
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None:
post_cast.copy_(pre_cast)
xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0]
bias = params[1]
if signature is not None:
if prefetch["signature"] is not None:
s._v_weight = weight
s._v_bias = bias
s._v_signature=signature
s._v_signature = prefetch["signature"]
#factor this our like you did before.
def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", [])
@ -203,14 +225,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x)
return x
update_weight = signature is not None
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
update_weight = prefetch["signature"] is not None
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
bias = None
if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
#FIXME: weird offload return protocol
return weight, bias, (offload_stream, device if signature is not None else None, None)
return weight, bias
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@ -228,6 +249,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None:
device = input.device
def format_return(result, offloadable):
weight, bias, offload_stream = result
return (weight, bias, offload_stream) if offloadable else (weight, bias)
non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"):
@ -243,13 +268,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return (weight, bias, (None, None, None)) if offloadable else (weight, bias)
return format_return((weight, bias, (None, None, None)), offloadable)
prefetched = hasattr(s, "_prefetch")
offload_stream = None
offload_device = None
if not prefetched:
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None:
offload_device = device
delattr(s, "_prefetch")
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
#check for a prefetch result here. Something like:
#if not prefetch:
#cast_modules([s], ...)
#this is the phase 2 call like you made before ...
return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)):
@ -296,11 +331,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function:
weight = f(weight)
if offloadable:
return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
def uncast_bias_weight(s, weight, bias, offload_stream):