mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 13:02:35 +08:00
Merge branch 'master' into feat/aggressive-offload
This commit is contained in:
commit
0bd2a353bf
@ -6,6 +6,7 @@ import comfy.utils
|
||||
import folder_paths
|
||||
from typing_extensions import override
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
import comfy.model_management
|
||||
|
||||
try:
|
||||
from spandrel_extra_arches import EXTRA_REGISTRY
|
||||
@ -78,13 +79,15 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
tile = 512
|
||||
overlap = 32
|
||||
|
||||
output_device = comfy.model_management.intermediate_device()
|
||||
|
||||
oom = True
|
||||
try:
|
||||
while oom:
|
||||
try:
|
||||
steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device)
|
||||
oom = False
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
@ -94,7 +97,7 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
finally:
|
||||
upscale_model.to("cpu")
|
||||
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0)
|
||||
s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype())
|
||||
return io.NodeOutput(s)
|
||||
|
||||
upscale = execute # TODO: remove
|
||||
|
||||
@ -18,6 +18,7 @@ class FakeLinearModel(nn.Module):
|
||||
"""Minimal nn.Module whose parameters consume measurable memory."""
|
||||
|
||||
def __init__(self, size_mb: float = 2.0):
|
||||
"""Create a model with approximately ``size_mb`` MB of float32 params."""
|
||||
super().__init__()
|
||||
# Each float32 param = 4 bytes, so `n` params ≈ size_mb * 1024² / 4
|
||||
n = int(size_mb * 1024 * 1024 / 4)
|
||||
@ -28,13 +29,16 @@ class FakeModelPatcher:
|
||||
"""Mimics the subset of ModelPatcher used by model_management.free_memory."""
|
||||
|
||||
def __init__(self, size_mb: float = 2.0):
|
||||
"""Create a patcher wrapping a ``FakeLinearModel`` of the given size."""
|
||||
self.model = FakeLinearModel(size_mb)
|
||||
self._loaded_size = int(size_mb * 1024 * 1024)
|
||||
|
||||
def loaded_size(self):
|
||||
"""Return reported loaded size in bytes."""
|
||||
return self._loaded_size
|
||||
|
||||
def is_dynamic(self):
|
||||
"""Static model — never dynamically offloaded."""
|
||||
return False
|
||||
|
||||
|
||||
@ -42,20 +46,25 @@ class FakeLoadedModel:
|
||||
"""Mimics LoadedModel entries in current_loaded_models."""
|
||||
|
||||
def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False):
|
||||
"""Wrap a ``FakeModelPatcher`` as a loaded-model entry."""
|
||||
self._model = patcher
|
||||
self.currently_used = currently_used
|
||||
|
||||
@property
|
||||
def model(self):
|
||||
"""Return the underlying model patcher."""
|
||||
return self._model
|
||||
|
||||
def model_memory(self):
|
||||
"""Return memory footprint in bytes."""
|
||||
return self._model.loaded_size()
|
||||
|
||||
def model_unload(self, _memory_to_free):
|
||||
"""Simulate successful unload."""
|
||||
return True
|
||||
|
||||
def model_load(self, _device, _keep_loaded):
|
||||
"""No-op load for testing."""
|
||||
pass
|
||||
|
||||
|
||||
@ -142,18 +151,19 @@ class TestModelDestroyedCallbacks:
|
||||
"""Validate the on_model_destroyed lifecycle callback system."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Reset the callback list before every test."""
|
||||
"""Save and reset the callback list before every test."""
|
||||
import comfy.model_management as mm
|
||||
self._original = mm._on_model_destroyed_callbacks.copy()
|
||||
mm._on_model_destroyed_callbacks.clear()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Restore the original callback list."""
|
||||
"""Restore the original callback list after every test."""
|
||||
import comfy.model_management as mm
|
||||
mm._on_model_destroyed_callbacks.clear()
|
||||
mm._on_model_destroyed_callbacks.extend(self._original)
|
||||
|
||||
def test_register_single_callback(self):
|
||||
"""A single registered callback must be stored and callable."""
|
||||
import comfy.model_management as mm
|
||||
|
||||
invocations = []
|
||||
@ -269,7 +279,7 @@ class TestMpsFlushConditionality:
|
||||
mm.AGGRESSIVE_OFFLOAD = original
|
||||
|
||||
def test_flush_requires_mps_device(self):
|
||||
"""The flush condition checks device.type == 'mps'."""
|
||||
"""The flush must only activate on MPS devices, not CPU or CUDA."""
|
||||
# Simulate CPU device — flush should not activate
|
||||
cpu_device = torch.device("cpu")
|
||||
assert cpu_device.type != "mps"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user