ops: split up prefetch from weight prep block prefetching API

Split up the casting and weight formating/lora stuff in prep for
arbitrary prefetch support.
This commit is contained in:
Rattus 2026-04-27 22:52:07 +10:00
parent 132c9f3ac6
commit 0e93c88c67

View File

@ -86,27 +86,29 @@ def materialize_meta_param(s, param_keys):
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): # FIXME: add n=1 cache hit fast path
#plan = [] def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
#Some sort of loop here like what you did
#for module in comfy_modules: ...
offload_stream = None offload_stream = None
xfer_dest = None cast_buffer = None
cast_buffer_offset = 0
for s in comfy_modules:
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
prefetch = {
"signature": signature,
"resident": resident,
}
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
if signature is not None:
if resident: if resident:
weight = s._v_weight s._prefetch = prefetch
bias = s._v_bias continue
else:
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
if not resident:
materialize_meta_param(s, ["weight", "bias"]) materialize_meta_param(s, ["weight", "bias"])
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
cast_dest = None cast_dest = None
needs_cast = False
xfer_source = [ s.weight, s.bias ] xfer_source = [ s.weight, s.bias ]
@ -118,25 +120,29 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
if data is None: if data is None:
continue continue
if data.dtype != geometry.dtype: if data.dtype != geometry.dtype:
needs_cast = True
cast_dest = xfer_dest cast_dest = xfer_dest
if cast_dest is None:
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
xfer_dest = None xfer_dest = None
break break
dest_size = comfy.memory_management.vram_aligned_size(xfer_source) dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
offload_stream = comfy.model_management.get_offload_stream(device) if offload_stream is None:
if xfer_dest is None and offload_stream is not None: offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) if xfer_dest is None and offload_stream is not None and cast_buffer is None:
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
if len(comfy_modules) == 1:
if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
offload_stream = comfy.model_management.get_offload_stream(device) offload_stream = comfy.model_management.get_offload_stream(device)
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size), device)
if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size) comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size)
if xfer_dest is None: if xfer_dest is None:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) if cast_buffer is not None:
offload_stream = None xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size, cast_buffer_offset), device)
cast_buffer_offset += dest_size
else:
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
offload_stream = None
if signature is None and pin is None: if signature is None and pin is None:
comfy.pinned_memory.pin_memory(s) comfy.pinned_memory.pin_memory(s)
@ -149,29 +155,45 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
xfer_source = [ pin ] xfer_source = [ pin ]
#send it over #send it over
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
#attach prefetch info to the module inside the loop .. prefetch["xfer_dest"] = xfer_dest
prefetch["cast_dest"] = cast_dest
prefetch["cast_geometry"] = cast_geometry
prefetch["needs_cast"] = needs_cast
s._prefetch = prefetch
#this sync is conceptually the last thing this function does - after the loop return offload_stream
comfy.model_management.sync_stream(device, offload_stream)
#all compute stuff need to be deferred to the new second phase
if cast_dest is not None:
def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
del non_blocking
prefetch = getattr(s, "_prefetch", None)
if prefetch is None:
raise RuntimeError("phase_2 called without a VBAR prefetch state")
if prefetch["resident"]:
weight = s._v_weight
bias = s._v_bias
else:
xfer_dest = prefetch["xfer_dest"]
if prefetch["needs_cast"]:
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
if post_cast is not None: if post_cast is not None:
post_cast.copy_(pre_cast) post_cast.copy_(pre_cast)
xfer_dest = cast_dest xfer_dest = cast_dest
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
weight = params[0] weight = params[0]
bias = params[1] bias = params[1]
if signature is not None: if prefetch["signature"] is not None:
s._v_weight = weight s._v_weight = weight
s._v_bias = bias s._v_bias = bias
s._v_signature=signature s._v_signature = prefetch["signature"]
#factor this our like you did before.
def post_cast(s, param_key, x, dtype, resident, update_weight): def post_cast(s, param_key, x, dtype, resident, update_weight):
lowvram_fn = getattr(s, param_key + "_lowvram_function", None) lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
fns = getattr(s, param_key + "_function", []) fns = getattr(s, param_key + "_function", [])
@ -203,14 +225,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
x = f(x) x = f(x)
return x return x
update_weight = signature is not None update_weight = prefetch["signature"] is not None
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
weight = post_cast(s, "weight", weight, dtype, resident, update_weight) bias = None
if s.bias is not None: if s.bias is not None:
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
#FIXME: weird offload return protocol return weight, bias
return weight, bias, (offload_stream, device if signature is not None else None, None)
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
@ -228,6 +249,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if device is None: if device is None:
device = input.device device = input.device
def format_return(result, offloadable):
weight, bias, offload_stream = result
return (weight, bias, offload_stream) if offloadable else (weight, bias)
non_blocking = comfy.model_management.device_supports_non_blocking(device) non_blocking = comfy.model_management.device_supports_non_blocking(device)
if hasattr(s, "_v"): if hasattr(s, "_v"):
@ -243,13 +268,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if isinstance(weight, QuantizedTensor): if isinstance(weight, QuantizedTensor):
weight = weight.dequantize() weight = weight.dequantize()
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
return (weight, bias, (None, None, None)) if offloadable else (weight, bias) return format_return((weight, bias, (None, None, None)), offloadable)
prefetched = hasattr(s, "_prefetch")
offload_stream = None
offload_device = None
if not prefetched:
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
comfy.model_management.sync_stream(device, offload_stream)
weight, bias = phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if not prefetched:
if getattr(s, "_prefetch")["signature"] is not None:
offload_device = device
delattr(s, "_prefetch")
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
#check for a prefetch result here. Something like:
#if not prefetch:
#cast_modules([s], ...)
#this is the phase 2 call like you made before ...
return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
if offloadable and (device != s.weight.device or if offloadable and (device != s.weight.device or
(s.bias is not None and device != s.bias.device)): (s.bias is not None and device != s.bias.device)):
@ -296,11 +331,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
for f in s.weight_function: for f in s.weight_function:
weight = f(weight) weight = f(weight)
if offloadable: return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
return weight, bias, (offload_stream, weight_a, bias_a)
else:
#Legacy function signature
return weight, bias
def uncast_bias_weight(s, weight, bias, offload_stream): def uncast_bias_weight(s, weight, bias, offload_stream):