Merge branch 'master' into feat/aggressive-offload

This commit is contained in:
Julián Mulet 2026-04-12 01:22:34 +02:00 committed by Julián Mulet
commit 0bd2a353bf
No known key found for this signature in database
2 changed files with 18 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ import comfy.utils
import folder_paths
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
import comfy.model_management
try:
from spandrel_extra_arches import EXTRA_REGISTRY
@@ -78,13 +79,15 @@ class ImageUpscaleWithModel(io.ComfyNode):
tile = 512
overlap = 32
output_device = comfy.model_management.intermediate_device()
oom = True
try:
while oom:
try:
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
oom = False
except Exception as e:
model_management.raise_non_oom(e)
@@ -94,7 +97,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
finally:
upscale_model.to("cpu")
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
return io.NodeOutput(s)
upscale = execute # TODO: remove

View File

@@ -18,6 +18,7 @@ class FakeLinearModel(nn.Module):
"""Minimal nn.Module whose parameters consume measurable memory."""
def __init__(self, size_mb: float = 2.0):
"""Create a model with approximately ``size_mb`` MB of float32 params."""
super().__init__()
# Each float32 param = 4 bytes, so `n` params ≈ size_mb * 1024² / 4
n = int(size_mb * 1024 * 1024 / 4)
@@ -28,13 +29,16 @@ class FakeModelPatcher:
"""Mimics the subset of ModelPatcher used by model_management.free_memory."""
def __init__(self, size_mb: float = 2.0):
"""Create a patcher wrapping a ``FakeLinearModel`` of the given size."""
self.model = FakeLinearModel(size_mb)
self._loaded_size = int(size_mb * 1024 * 1024)
def loaded_size(self):
"""Return reported loaded size in bytes."""
return self._loaded_size
def is_dynamic(self):
"""Static model — never dynamically offloaded."""
return False
@@ -42,20 +46,25 @@ class FakeLoadedModel:
"""Mimics LoadedModel entries in current_loaded_models."""
def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False):
"""Wrap a ``FakeModelPatcher`` as a loaded-model entry."""
self._model = patcher
self.currently_used = currently_used
@property
def model(self):
"""Return the underlying model patcher."""
return self._model
def model_memory(self):
"""Return memory footprint in bytes."""
return self._model.loaded_size()
def model_unload(self, _memory_to_free):
"""Simulate successful unload."""
return True
def model_load(self, _device, _keep_loaded):
"""No-op load for testing."""
pass
@@ -142,18 +151,19 @@ class TestModelDestroyedCallbacks:
"""Validate the on_model_destroyed lifecycle callback system."""
def setup_method(self):
"""Reset the callback list before every test."""
"""Save and reset the callback list before every test."""
import comfy.model_management as mm
self._original = mm._on_model_destroyed_callbacks.copy()
mm._on_model_destroyed_callbacks.clear()
def teardown_method(self):
"""Restore the original callback list."""
"""Restore the original callback list after every test."""
import comfy.model_management as mm
mm._on_model_destroyed_callbacks.clear()
mm._on_model_destroyed_callbacks.extend(self._original)
def test_register_single_callback(self):
"""A single registered callback must be stored and callable."""
import comfy.model_management as mm
invocations = []
@@ -269,7 +279,7 @@ class TestMpsFlushConditionality:
mm.AGGRESSIVE_OFFLOAD = original
def test_flush_requires_mps_device(self):
"""The flush condition checks device.type == 'mps'."""
"""The flush must only activate on MPS devices, not CPU or CUDA."""
# Simulate CPU device — flush should not activate
cpu_device = torch.device("cpu")
assert cpu_device.type != "mps"