diff --git a/comfy/ops.py b/comfy/ops.py index 6a4189791..6baad439f 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -86,27 +86,29 @@ def materialize_meta_param(s, param_keys): setattr(s, param_key, torch.nn.Parameter(torch.zeros(param.shape, dtype=param.dtype), requires_grad=param.requires_grad)) -def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): - #plan = [] - #Some sort of loop here like what you did - #for module in comfy_modules: ... - +# FIXME: add n=1 cache hit fast path +def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blocking): offload_stream = None - xfer_dest = None + cast_buffer = None + cast_buffer_offset = 0 + + for s in comfy_modules: + signature = comfy_aimdo.model_vbar.vbar_fault(s._v) + resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) + prefetch = { + "signature": signature, + "resident": resident, + } - signature = comfy_aimdo.model_vbar.vbar_fault(s._v) - resident = comfy_aimdo.model_vbar.vbar_signature_compare(signature, s._v_signature) - if signature is not None: if resident: - weight = s._v_weight - bias = s._v_bias - else: - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) + s._prefetch = prefetch + continue - if not resident: materialize_meta_param(s, ["weight", "bias"]) + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(s._v, device) if signature is not None else None cast_geometry = comfy.memory_management.tensors_to_geometries([ s.weight, s.bias ]) cast_dest = None + needs_cast = False xfer_source = [ s.weight, s.bias ] @@ -118,25 +120,29 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu if data is None: continue if data.dtype != geometry.dtype: + needs_cast = True cast_dest = xfer_dest - if cast_dest is None: - cast_dest = torch.empty((comfy.memory_management.vram_aligned_size(cast_geometry),), dtype=torch.uint8, device=device) xfer_dest = None break dest_size = comfy.memory_management.vram_aligned_size(xfer_source) - offload_stream = comfy.model_management.get_offload_stream(device) - if xfer_dest is None and offload_stream is not None: - cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + if offload_stream is None: + offload_stream = comfy.model_management.get_offload_stream(device) + if xfer_dest is None and offload_stream is not None and cast_buffer is None: + cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) + if len(comfy_modules) == 1: if cast_buffer.size() < dest_size and s is comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[0]: offload_stream = comfy.model_management.get_offload_stream(device) cast_buffer = comfy.model_management.get_aimdo_cast_buffer(offload_stream, device) - xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size), device) if dest_size > comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT[1]: comfy.model_management.LARGEST_AIMDO_CASTED_WEIGHT = (s, dest_size) if xfer_dest is None: - xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) - offload_stream = None + if cast_buffer is not None: + xfer_dest = comfy_aimdo.torch.aimdo_to_tensor(cast_buffer.get(dest_size, cast_buffer_offset), device) + cast_buffer_offset += dest_size + else: + xfer_dest = torch.empty((dest_size,), dtype=torch.uint8, device=device) + offload_stream = None if signature is None and pin is None: comfy.pinned_memory.pin_memory(s) @@ -149,29 +155,45 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu xfer_source = [ pin ] #send it over comfy.model_management.cast_to_gathered(xfer_source, xfer_dest, non_blocking=non_blocking, stream=offload_stream) - #attach prefetch info to the module inside the loop .. + prefetch["xfer_dest"] = xfer_dest + prefetch["cast_dest"] = cast_dest + prefetch["cast_geometry"] = cast_geometry + prefetch["needs_cast"] = needs_cast + s._prefetch = prefetch - #this sync is conceptually the last thing this function does - after the loop - comfy.model_management.sync_stream(device, offload_stream) + return offload_stream - #all compute stuff need to be deferred to the new second phase - if cast_dest is not None: + + +def phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant): + del non_blocking + + prefetch = getattr(s, "_prefetch", None) + if prefetch is None: + raise RuntimeError("phase_2 called without a VBAR prefetch state") + + if prefetch["resident"]: + weight = s._v_weight + bias = s._v_bias + else: + xfer_dest = prefetch["xfer_dest"] + if prefetch["needs_cast"]: + cast_dest = prefetch["cast_dest"] if prefetch["cast_dest"] is not None else torch.empty((comfy.memory_management.vram_aligned_size(prefetch["cast_geometry"]),), dtype=torch.uint8, device=device) for pre_cast, post_cast in zip(comfy.memory_management.interpret_gathered_like([s.weight, s.bias ], xfer_dest), - comfy.memory_management.interpret_gathered_like(cast_geometry, cast_dest)): + comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], cast_dest)): if post_cast is not None: post_cast.copy_(pre_cast) xfer_dest = cast_dest - params = comfy.memory_management.interpret_gathered_like(cast_geometry, xfer_dest) + params = comfy.memory_management.interpret_gathered_like(prefetch["cast_geometry"], xfer_dest) weight = params[0] bias = params[1] - if signature is not None: + if prefetch["signature"] is not None: s._v_weight = weight s._v_bias = bias - s._v_signature=signature + s._v_signature = prefetch["signature"] - #factor this our like you did before. def post_cast(s, param_key, x, dtype, resident, update_weight): lowvram_fn = getattr(s, param_key + "_lowvram_function", None) fns = getattr(s, param_key + "_function", []) @@ -203,14 +225,13 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu x = f(x) return x - update_weight = signature is not None - - weight = post_cast(s, "weight", weight, dtype, resident, update_weight) + update_weight = prefetch["signature"] is not None + weight = post_cast(s, "weight", weight, dtype, prefetch["resident"], update_weight) + bias = None if s.bias is not None: - bias = post_cast(s, "bias", bias, bias_dtype, resident, update_weight) + bias = post_cast(s, "bias", bias, bias_dtype, prefetch["resident"], update_weight) - #FIXME: weird offload return protocol - return weight, bias, (offload_stream, device if signature is not None else None, None) + return weight, bias def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False): @@ -228,6 +249,10 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if device is None: device = input.device + def format_return(result, offloadable): + weight, bias, offload_stream = result + return (weight, bias, offload_stream) if offloadable else (weight, bias) + non_blocking = comfy.model_management.device_supports_non_blocking(device) if hasattr(s, "_v"): @@ -243,13 +268,23 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of if isinstance(weight, QuantizedTensor): weight = weight.dequantize() bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None - return (weight, bias, (None, None, None)) if offloadable else (weight, bias) + return format_return((weight, bias, (None, None, None)), offloadable) + + prefetched = hasattr(s, "_prefetch") + offload_stream = None + offload_device = None + if not prefetched: + offload_stream = cast_modules_with_vbar([s], dtype, device, bias_dtype, non_blocking) + comfy.model_management.sync_stream(device, offload_stream) + + weight, bias = phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) + + if not prefetched: + if getattr(s, "_prefetch")["signature"] is not None: + offload_device = device + delattr(s, "_prefetch") + return format_return((weight, bias, (offload_stream, offload_device, None)), offloadable) - #check for a prefetch result here. Something like: - #if not prefetch: - #cast_modules([s], ...) - #this is the phase 2 call like you made before ... - return phase_2(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant) if offloadable and (device != s.weight.device or (s.bias is not None and device != s.bias.device)): @@ -296,11 +331,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of for f in s.weight_function: weight = f(weight) - if offloadable: - return weight, bias, (offload_stream, weight_a, bias_a) - else: - #Legacy function signature - return weight, bias + return format_return((weight, bias, (offload_stream, weight_a, bias_a)), offloadable) def uncast_bias_weight(s, weight, bias, offload_stream):