mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-26 17:07:25 +08:00
ops: split up prefetch from weight prep block prefetching API
Split up the casting and weight formating/lora stuff in prep for arbitrary prefetch support.
This commit is contained in:
parent
132c9f3ac6
commit
0e93c88c67
129
comfy/ops.py
129
comfy/ops.py
@ -86,27 +86,29 @@ def materialize_meta_param(s, param_keys):
|
|||||||
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad))
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
# FIXME: add n=1 cache hit fast path
|
||||||
#plan = []
|
def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking):
|
||||||
#Some sort of loop here like what you did
|
|
||||||
#for module in comfy_modules: ...
|
|
||||||
|
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
xfer_dest = None
|
cast_buffer = None
|
||||||
|
cast_buffer_offset = 0
|
||||||
|
|
||||||
|
for s in comfy_modules:
|
||||||
|
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
||||||
|
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
||||||
|
prefetch = {
|
||||||
|
"signature": signature,
|
||||||
|
"resident": resident,
|
||||||
|
}
|
||||||
|
|
||||||
signature = comfy_aimdo.model_vbar.vbar_fault(s._v)
|
|
||||||
resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature)
|
|
||||||
if signature is not None:
|
|
||||||
if resident:
|
if resident:
|
||||||
weight = s._v_weight
|
s._prefetch = prefetch
|
||||||
bias = s._v_bias
|
continue
|
||||||
else:
|
|
||||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device)
|
|
||||||
|
|
||||||
if not resident:
|
|
||||||
materialize_meta_param(s, ["weight", "bias"])
|
materialize_meta_param(s, ["weight", "bias"])
|
||||||
|
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None
|
||||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ])
|
||||||
cast_dest = None
|
cast_dest = None
|
||||||
|
needs_cast = False
|
||||||
|
|
||||||
xfer_source = [ s.weight, s.bias ]
|
xfer_source = [ s.weight, s.bias ]
|
||||||
|
|
||||||
@ -118,25 +120,29 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
if data is None:
|
if data is None:
|
||||||
continue
|
continue
|
||||||
if data.dtype != geometry.dtype:
|
if data.dtype != geometry.dtype:
|
||||||
|
needs_cast = True
|
||||||
cast_dest = xfer_dest
|
cast_dest = xfer_dest
|
||||||
if cast_dest is None:
|
|
||||||
cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device)
|
|
||||||
xfer_dest = None
|
xfer_dest = None
|
||||||
break
|
break
|
||||||
|
|
||||||
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
dest_size = comfy.memory_management.vram_aligned_size(xfer_source)
|
||||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
if offload_stream is None:
|
||||||
if xfer_dest is None and offload_stream is not None:
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
if xfer_dest is None and offload_stream is not None and cast_buffer is None:
|
||||||
|
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
||||||
|
if len(comfy_modules) == 1:
|
||||||
if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
|
if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]:
|
||||||
offload_stream = comfy.model_management.get_offload_stream(device)
|
offload_stream = comfy.model_management.get_offload_stream(device)
|
||||||
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device)
|
||||||
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size), device)
|
|
||||||
if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
|
if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]:
|
||||||
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size)
|
comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size)
|
||||||
if xfer_dest is None:
|
if xfer_dest is None:
|
||||||
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
if cast_buffer is not None:
|
||||||
offload_stream = None
|
xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size, cast_buffer_offset), device)
|
||||||
|
cast_buffer_offset += dest_size
|
||||||
|
else:
|
||||||
|
xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device)
|
||||||
|
offload_stream = None
|
||||||
|
|
||||||
if signature is None and pin is None:
|
if signature is None and pin is None:
|
||||||
comfy.pinned_memory.pin_memory(s)
|
comfy.pinned_memory.pin_memory(s)
|
||||||
@ -149,29 +155,45 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
xfer_source = [ pin ]
|
xfer_source = [ pin ]
|
||||||
#send it over
|
#send it over
|
||||||
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream)
|
||||||
#attach prefetch info to the module inside the loop ..
|
prefetch["xfer_dest"] = xfer_dest
|
||||||
|
prefetch["cast_dest"] = cast_dest
|
||||||
|
prefetch["cast_geometry"] = cast_geometry
|
||||||
|
prefetch["needs_cast"] = needs_cast
|
||||||
|
s._prefetch = prefetch
|
||||||
|
|
||||||
#this sync is conceptually the last thing this function does - after the loop
|
return offload_stream
|
||||||
comfy.model_management.sync_stream(device, offload_stream)
|
|
||||||
|
|
||||||
|
|
||||||
#all compute stuff need to be deferred to the new second phase
|
|
||||||
if cast_dest is not None:
|
|
||||||
|
def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||||
|
del non_blocking
|
||||||
|
|
||||||
|
prefetch = getattr(s, "_prefetch", None)
|
||||||
|
if prefetch is None:
|
||||||
|
raise RuntimeError("phase_2 called without a VBAR prefetch state")
|
||||||
|
|
||||||
|
if prefetch["resident"]:
|
||||||
|
weight = s._v_weight
|
||||||
|
bias = s._v_bias
|
||||||
|
else:
|
||||||
|
xfer_dest = prefetch["xfer_dest"]
|
||||||
|
if prefetch["needs_cast"]:
|
||||||
|
cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device)
|
||||||
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest),
|
||||||
comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)):
|
comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)):
|
||||||
if post_cast is not None:
|
if post_cast is not None:
|
||||||
post_cast.copy_(pre_cast)
|
post_cast.copy_(pre_cast)
|
||||||
xfer_dest = cast_dest
|
xfer_dest = cast_dest
|
||||||
|
|
||||||
params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest)
|
params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest)
|
||||||
weight = params[0]
|
weight = params[0]
|
||||||
bias = params[1]
|
bias = params[1]
|
||||||
if signature is not None:
|
if prefetch["signature"] is not None:
|
||||||
s._v_weight = weight
|
s._v_weight = weight
|
||||||
s._v_bias = bias
|
s._v_bias = bias
|
||||||
s._v_signature=signature
|
s._v_signature = prefetch["signature"]
|
||||||
|
|
||||||
#factor this our like you did before.
|
|
||||||
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
def post_cast(s, param_key, x, dtype, resident, update_weight):
|
||||||
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
|
||||||
fns = getattr(s, param_key + "_function", [])
|
fns = getattr(s, param_key + "_function", [])
|
||||||
@ -203,14 +225,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
x = f(x)
|
x = f(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
update_weight = signature is not None
|
update_weight = prefetch["signature"] is not None
|
||||||
|
weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight)
|
||||||
weight = post_cast(s, "weight", weight, dtype, resident, update_weight)
|
bias = None
|
||||||
if s.bias is not None:
|
if s.bias is not None:
|
||||||
bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight)
|
bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight)
|
||||||
|
|
||||||
#FIXME: weird offload return protocol
|
return weight, bias
|
||||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||||
@ -228,6 +249,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
if device is None:
|
if device is None:
|
||||||
device = input.device
|
device = input.device
|
||||||
|
|
||||||
|
def format_return(result, offloadable):
|
||||||
|
weight, bias, offload_stream = result
|
||||||
|
return (weight, bias, offload_stream) if offloadable else (weight, bias)
|
||||||
|
|
||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
|
|
||||||
if hasattr(s, "_v"):
|
if hasattr(s, "_v"):
|
||||||
@ -243,13 +268,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
if isinstance(weight, QuantizedTensor):
|
if isinstance(weight, QuantizedTensor):
|
||||||
weight = weight.dequantize()
|
weight = weight.dequantize()
|
||||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||||
return (weight, bias, (None, None, None)) if offloadable else (weight, bias)
|
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||||
|
|
||||||
|
prefetched = hasattr(s, "_prefetch")
|
||||||
|
offload_stream = None
|
||||||
|
offload_device = None
|
||||||
|
if not prefetched:
|
||||||
|
offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking)
|
||||||
|
comfy.model_management.sync_stream(device, offload_stream)
|
||||||
|
|
||||||
|
weight, bias = phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||||
|
|
||||||
|
if not prefetched:
|
||||||
|
if getattr(s, "_prefetch")["signature"] is not None:
|
||||||
|
offload_device = device
|
||||||
|
delattr(s, "_prefetch")
|
||||||
|
return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable)
|
||||||
|
|
||||||
#check for a prefetch result here. Something like:
|
|
||||||
#if not prefetch:
|
|
||||||
#cast_modules([s], ...)
|
|
||||||
#this is the phase 2 call like you made before ...
|
|
||||||
return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
|
||||||
|
|
||||||
if offloadable and (device != s.weight.device or
|
if offloadable and (device != s.weight.device or
|
||||||
(s.bias is not None and device != s.bias.device)):
|
(s.bias is not None and device != s.bias.device)):
|
||||||
@ -296,11 +331,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
for f in s.weight_function:
|
for f in s.weight_function:
|
||||||
weight = f(weight)
|
weight = f(weight)
|
||||||
|
|
||||||
if offloadable:
|
return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable)
|
||||||
return weight, bias, (offload_stream, weight_a, bias_a)
|
|
||||||
else:
|
|
||||||
#Legacy function signature
|
|
||||||
return weight, bias
|
|
||||||
|
|
||||||
|
|
||||||
def uncast_bias_weight(s, weight, bias, offload_stream):
|
def uncast_bias_weight(s, weight, bias, offload_stream):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user