mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-03 02:00:29 +08:00
dynamic_vram: respect argument cast dtypes in non-comfy weights (#12209)
This function has a dtype argument that allows the caller to set the dtype in the cast. TIL Some models override this on weight casts, which means its the highest priority. Priority scheme is: argument > model dtype > state dict dtype
This commit is contained in:
parent
361b9a82a3
commit
794d05bdb1
@ -1202,27 +1202,36 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
||||
assert r is None
|
||||
assert stream is None
|
||||
|
||||
r = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
|
||||
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
|
||||
|
||||
if dtype is None:
|
||||
dtype = weight._model_dtype
|
||||
|
||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||
|
||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||
if signature is not None:
|
||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0]
|
||||
|
||||
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
|
||||
if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
|
||||
weight._v_signature = signature
|
||||
#Send it over
|
||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
|
||||
#a non comfy weight
|
||||
r.copy_(v_tensor)
|
||||
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
|
||||
return r
|
||||
|
||||
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
||||
#Offloaded casting could skip this, however it would make the quantizations
|
||||
#inconsistent between loaded and offloaded weights. So force the double casting
|
||||
#that would happen in regular flow to make offload deterministic.
|
||||
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
|
||||
cast_buffer.copy_(weight, non_blocking=non_blocking)
|
||||
weight = cast_buffer
|
||||
r.copy_(weight, non_blocking=non_blocking)
|
||||
|
||||
if signature is not None:
|
||||
weight._v_signature = signature
|
||||
v_tensor.copy_(r)
|
||||
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
|
||||
|
||||
return r
|
||||
|
||||
if device is None or weight.device == device:
|
||||
|
||||
Loading…
Reference in New Issue
Block a user