mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-11 20:27:44 +08:00
load non comfy weights
This commit is contained in:
parent
6198f7562e
commit
959b3014bb
@ -1206,8 +1206,6 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
|||||||
if dtype is None:
|
if dtype is None:
|
||||||
dtype = weight._model_dtype
|
dtype = weight._model_dtype
|
||||||
|
|
||||||
r = torch.empty_like(weight, dtype=dtype, device=device)
|
|
||||||
|
|
||||||
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
|
||||||
if signature is not None:
|
if signature is not None:
|
||||||
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
|
||||||
@ -1216,11 +1214,9 @@ def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, str
|
|||||||
weight._v_signature = signature
|
weight._v_signature = signature
|
||||||
#Send it over
|
#Send it over
|
||||||
v_tensor.copy_(weight, non_blocking=non_blocking)
|
v_tensor.copy_(weight, non_blocking=non_blocking)
|
||||||
#always take a deep copy even if _v is good, as we have no reasonable point to unpin
|
return v_tensor.to(dtype=dtype)
|
||||||
#a non comfy weight
|
|
||||||
r.copy_(v_tensor)
|
r = torch.empty_like(weight, dtype=dtype, device=device)
|
||||||
comfy_aimdo.model_vbar.vbar_unpin(weight._v)
|
|
||||||
return r
|
|
||||||
|
|
||||||
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
|
||||||
#Offloaded casting could skip this, however it would make the quantizations
|
#Offloaded casting could skip this, however it would make the quantizations
|
||||||
|
|||||||
@ -1492,7 +1492,9 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
if vbar is not None:
|
if vbar is not None:
|
||||||
vbar.prioritize()
|
vbar.prioritize()
|
||||||
|
|
||||||
#We have way more tools for acceleration on comfy weight offloading, so always
|
#We force reserve VRAM for the non comfy-weight so we don't have to deal
|
||||||
|
#with pin and unpin synchronization which can be expensive for small weights
|
||||||
|
#with a high layer rate (e.g. autoregressive LLMs).
|
||||||
#prioritize the non-comfy weights (note the order reverse).
|
#prioritize the non-comfy weights (note the order reverse).
|
||||||
loading = self._load_list(prio_comfy_cast_weights=True)
|
loading = self._load_list(prio_comfy_cast_weights=True)
|
||||||
loading.sort(reverse=True)
|
loading.sort(reverse=True)
|
||||||
@ -1557,6 +1559,7 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
weight._v = vbar.alloc(weight_size)
|
weight._v = vbar.alloc(weight_size)
|
||||||
weight._model_dtype = model_dtype
|
weight._model_dtype = model_dtype
|
||||||
allocated_size += weight_size
|
allocated_size += weight_size
|
||||||
|
vbar.set_watermark_limit(allocated_size)
|
||||||
|
|
||||||
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
|
||||||
|
|
||||||
|
|||||||
@ -13,8 +13,11 @@ from contextlib import nullcontext
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
import comfy_aimdo.model_vbar
|
||||||
|
|
||||||
from latent_preview import set_preview_method
|
from latent_preview import set_preview_method
|
||||||
import nodes
|
import nodes
|
||||||
from comfy_execution.caching import (
|
from comfy_execution.caching import (
|
||||||
@ -527,8 +530,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
output_data, output_ui, has_subgraph, has_pending_tasks = await get_output_data(prompt_id, unique_id, obj, input_data_all, execution_block_cb=execution_block_cb, pre_execute_cb=pre_execute_cb, v3_data=v3_data)
|
||||||
finally:
|
finally:
|
||||||
if allocator is not None:
|
if allocator is not None:
|
||||||
|
if args.verbose == "DEBUG":
|
||||||
|
comfy_aimdo.model_vbar.vbars_analyze()
|
||||||
comfy.model_management.reset_cast_buffers()
|
comfy.model_management.reset_cast_buffers()
|
||||||
torch.cuda.synchronize()
|
comfy_aimdo.model_vbar.vbars_reset_watermark_limits()
|
||||||
|
|
||||||
if has_pending_tasks:
|
if has_pending_tasks:
|
||||||
pending_async_nodes[unique_id] = output_data
|
pending_async_nodes[unique_id] = output_data
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user