Rattus 2026-01-13 20:29:13 +10:00
parent 94a709d813
commit 013f132085
5 changed files with 3 additions and 11 deletions

View File

@@ -1,10 +1,7 @@
import torch
-from comfy.quant_ops import QuantizedTensor
-import comfy_aimdo.torch
-import logging
def vram_aligned_size(tensor):
    if isinstance(tensor, list):
        return sum([vram_aligned_size(t) for t in tensor])
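Note: the hunk above only shows the list branch of `vram_aligned_size`; the per-tensor branch lies outside the diff context. A minimal sketch of the idea, assuming a hypothetical 512-byte alignment boundary (the real constant and any quantized-tensor handling are not visible in this commit):

```python
import torch

ALIGNMENT = 512  # hypothetical; the real boundary is not shown in this diff

def vram_aligned_size_sketch(tensor):
    # Lists are sized recursively, mirroring the branch visible in the hunk.
    if isinstance(tensor, list):
        return sum(vram_aligned_size_sketch(t) for t in tensor)
    # Assumed per-tensor branch: round the raw byte size up to the boundary.
    nbytes = tensor.numel() * tensor.element_size()
    return ((nbytes + ALIGNMENT - 1) // ALIGNMENT) * ALIGNMENT

print(vram_aligned_size_sketch([torch.zeros(3), torch.zeros(1000)]))  # 512 + 4096 = 4608
```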

View File

@@ -24,7 +24,6 @@ import inspect
import logging
import math
import uuid
-import types
from typing import Callable, Optional
import torch
@@ -1381,7 +1380,7 @@ class ModelPatcher:
unet_state_dict = self.model.diffusion_model.state_dict()
for k, v in unet_state_dict.items():
    op_keys = k.rsplit('.', 1)
-    if (len(op_keys) < 2) or not op_keys[1] in ["weight", "bias"]:
+    if (len(op_keys) < 2) or op_keys[1] not in ["weight", "bias"]:
        continue
    try:
        op = comfy.utils.get_attr(self.model.diffusion_model, op_keys[0])
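The one-line change above is purely stylistic: `not x in y` parses as `not (x in y)`, so it is equivalent to the idiomatic `x not in y` that replaces it, e.g.:

```python
op_keys = "diffusion_model.out.weight".rsplit('.', 1)  # ['diffusion_model.out', 'weight']
assert (not op_keys[1] in ["weight", "bias"]) == (op_keys[1] not in ["weight", "bias"])
```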
@@ -1467,7 +1466,7 @@ class ModelPatcherDynamic(ModelPatcher):
#Full load doesn't make sense as we don't actually have any loader capability here and
#now.
-assert not full_load;
+assert not full_load
assert device_to == self.load_device
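Similarly cosmetic: the trailing semicolon on `assert not full_load;` is legal Python (an empty statement separator), so dropping it only satisfies linters such as pycodestyle's E703; runtime behavior is unchanged:

```python
full_load = False
assert not full_load;   # legal, but flagged by pycodestyle (E703)
assert not full_load    # same statement, idiomatic form
```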

View File

@@ -128,11 +128,9 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
def post_cast(s, param_key, x, dtype, resident, update_weight):
    lowvram_fn = getattr(s, param_key + "_lowvram_function", None)
    hook_fn = getattr(s, param_key + "_hooks", None)
-    fns = getattr(s, param_key + "_function", [])
    orig = x
-    q_layout = None
    def to_dequant(tensor, dtype):
        tensor = tensor.to(dtype=dtype)
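The `getattr` calls above implement optional per-parameter callbacks: a module may carry attributes named `<param_key>_lowvram_function` or `<param_key>_hooks`, and absent ones default to `None`. A small sketch of that lookup contract (the attribute naming follows the hunk; everything else here is illustrative):

```python
import torch

layer = torch.nn.Linear(4, 4)
# Attach an optional transform under the '<param_key>_lowvram_function'
# naming convention that post_cast probes with getattr().
layer.weight_lowvram_function = lambda w: w.half()

for param_key in ("weight", "bias"):
    fn = getattr(layer, param_key + "_lowvram_function", None)
    tensor = getattr(layer, param_key)
    if fn is not None:  # only 'weight' carries a transform in this example
        tensor = fn(tensor)
    print(param_key, tensor.dtype)  # weight -> torch.float16, bias -> torch.float32
```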
@@ -218,7 +216,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
if cast_buffer is None:
    offload_stream = comfy.model_management.get_offload_stream(device)
    cast_buffer = comfy.model_management.get_cast_buffer(offload_stream, device, cast_buffer_size, s)
-params = interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
+params = comfy.memory_management.interpret_gathered_like([ s.weight, s.bias ], cast_buffer)
weight = params[0]
bias = params[1]
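The second change qualifies `interpret_gathered_like` with its `comfy.memory_management` namespace (the bare name was presumably left over from a refactor). Judging only from the call site, the function views one flat cast buffer as tensors shaped and typed like `s.weight` and `s.bias`. A hypothetical sketch of that behavior, not the actual implementation:

```python
import torch

def interpret_gathered_like_sketch(tensors, buffer):
    # Hypothetical: carve a flat uint8 buffer into views that match the
    # shape and dtype of each reference tensor, packed back to back.
    views, offset = [], 0
    for t in tensors:
        nbytes = t.numel() * t.element_size()
        chunk = buffer[offset:offset + nbytes].view(t.dtype).view(t.shape)
        views.append(chunk)
        offset += nbytes
    return views

buf = torch.empty(1024, dtype=torch.uint8)
ref_w, ref_b = torch.zeros(4, 4), torch.zeros(4)
weight, bias = interpret_gathered_like_sketch([ref_w, ref_b], buf)
print(weight.shape, bias.shape)  # torch.Size([4, 4]) torch.Size([4])
```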

View File

@@ -1,5 +1,4 @@
import torch
import logging
import comfy.model_management
import comfy.memory_management

View File

@@ -9,7 +9,6 @@ if TYPE_CHECKING:
import torch
from functools import partial
import collections
-from comfy import model_management
import math
import logging
import comfy.sampler_helpers