dynamic_vram: respect argument cast dtypes in non-comfy weights (#12209)

This function has a dtype argument that allows the caller to set the
dtype of the cast. It turns out some models override this on weight
casts, which means the argument has to take the highest priority.

Priority scheme is: argument > model dtype > state dict dtype
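
As a rough standalone illustration of that ordering (this helper and its argument names are made up for the example, not part of the ComfyUI code):

    import torch

    def resolve_cast_dtype(arg_dtype, model_dtype, state_dict_dtype):
        # argument > model dtype > state dict dtype
        if arg_dtype is not None:
            return arg_dtype
        if model_dtype is not None:
            return model_dtype
        return state_dict_dtype

    # A model that overrides the dtype on a weight cast wins over everything else:
    print(resolve_cast_dtype(torch.float16, torch.bfloat16, torch.float32))  # torch.float16
    print(resolve_cast_dtype(None, torch.bfloat16, torch.float32))           # torch.bfloat16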
Author: rattus (committed by GitHub)
Date: 2026-02-01 17:09:21 -08:00
Parent: 361b9a82a3
Commit: 794d05bdb1


@@ -1202,27 +1202,36 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
        assert r is None
        assert stream is None
        r = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
        cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
        if dtype is None:
            dtype = weight._model_dtype
        r = torch.empty_like(weight, dtype=dtype, device=device)
        signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
        if signature is not None:
            raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
            v_tensor = comfy.memory_management.interpret_gathered_like([r], raw_tensor)[0]
            if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
            v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
            if not comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
                weight._v_signature = signature
                #Send it over
                v_tensor.copy_(weight, non_blocking=non_blocking)
            #always take a deep copy even if _v is good, as we have no reasonable point to unpin
            #a non comfy weight
            r.copy_(v_tensor)
            comfy_aimdo.model_vbar.vbar_unpin(weight._v)
            return r
        if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
            #Offloaded casting could skip this, however it would make the quantizations
            #inconsistent between loaded and offloaded weights. So force the double casting
            #that would happen in regular flow to make offload deterministic.
            cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
            cast_buffer.copy_(weight, non_blocking=non_blocking)
            weight = cast_buffer
        r.copy_(weight, non_blocking=non_blocking)
        if signature is not None:
            weight._v_signature = signature
            v_tensor.copy_(r)
            comfy_aimdo.model_vbar.vbar_unpin(weight._v)
        return r
    if device is None or weight.device == device:
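
For reference, a minimal standalone sketch of the double-cast step in the offload fallback above (the helper below is hypothetical, not ComfyUI's API): hopping through the model dtype before the final cast reproduces the rounding of the regular loaded path, so offloaded weights stay deterministic.

    import torch

    def double_cast(weight, model_dtype, requested_dtype):
        # Mirror the cast_buffer step: cast to the model dtype first so the
        # result matches what a loaded (non-offloaded) weight would produce.
        if weight.dtype != requested_dtype and weight.dtype != model_dtype:
            weight = weight.to(model_dtype)
        return weight.to(requested_dtype)

    w = torch.randn(4, dtype=torch.float32)               # state dict dtype
    print(double_cast(w, torch.bfloat16, torch.float16))  # fp32 -> bf16 -> fp16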