diff --git a/comfy_extras/nodes_upscale_model.py b/comfy_extras/nodes_upscale_model.py index db4f9d231..d3ee3f1c1 100644 --- a/comfy_extras/nodes_upscale_model.py +++ b/comfy_extras/nodes_upscale_model.py @@ -6,6 +6,7 @@ import comfy.utils import folder_paths from typing_extensions import override from comfy_api.latest import ComfyExtension, io +import comfy.model_management try: from spandrel_extra_arches import EXTRA_REGISTRY @@ -78,13 +79,15 @@ class ImageUpscaleWithModel(io.ComfyNode): tile = 512 overlap = 32 + output_device = comfy.model_management.intermediate_device() + oom = True try: while oom: try: steps = in_img.shape[0] * comfy.utils.get_tiled_scale_steps(in_img.shape[3], in_img.shape[2], tile_x=tile, tile_y=tile, overlap=overlap) pbar = comfy.utils.ProgressBar(steps) - s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) + s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a.float()), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar, output_device=output_device) oom = False except Exception as e: model_management.raise_non_oom(e) @@ -94,7 +97,7 @@ class ImageUpscaleWithModel(io.ComfyNode): finally: upscale_model.to("cpu") - s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0) + s = torch.clamp(s.movedim(-3,-1), min=0, max=1.0).to(comfy.model_management.intermediate_dtype()) return io.NodeOutput(s) upscale = execute # TODO: remove diff --git a/tests-unit/test_aggressive_offload.py b/tests-unit/test_aggressive_offload.py index 6caa468d1..130023242 100644 --- a/tests-unit/test_aggressive_offload.py +++ b/tests-unit/test_aggressive_offload.py @@ -18,6 +18,7 @@ class FakeLinearModel(nn.Module): """Minimal nn.Module whose parameters consume measurable memory.""" def __init__(self, size_mb: float = 2.0): + """Create a model with approximately ``size_mb`` MB of float32 params.""" super().__init__() # Each float32 param = 4 bytes, so `n` params ≈ size_mb * 1024² / 4 n = int(size_mb * 1024 * 1024 / 4) @@ -28,13 +29,16 @@ class FakeModelPatcher: """Mimics the subset of ModelPatcher used by model_management.free_memory.""" def __init__(self, size_mb: float = 2.0): + """Create a patcher wrapping a ``FakeLinearModel`` of the given size.""" self.model = FakeLinearModel(size_mb) self._loaded_size = int(size_mb * 1024 * 1024) def loaded_size(self): + """Return reported loaded size in bytes.""" return self._loaded_size def is_dynamic(self): + """Static model — never dynamically offloaded.""" return False @@ -42,20 +46,25 @@ class FakeLoadedModel: """Mimics LoadedModel entries in current_loaded_models.""" def __init__(self, patcher: FakeModelPatcher, *, currently_used: bool = False): + """Wrap a ``FakeModelPatcher`` as a loaded-model entry.""" self._model = patcher self.currently_used = currently_used @property def model(self): + """Return the underlying model patcher.""" return self._model def model_memory(self): + """Return memory footprint in bytes.""" return self._model.loaded_size() def model_unload(self, _memory_to_free): + """Simulate successful unload.""" return True def model_load(self, _device, _keep_loaded): + """No-op load for testing.""" pass @@ -142,18 +151,19 @@ class TestModelDestroyedCallbacks: """Validate the on_model_destroyed lifecycle callback system.""" def setup_method(self): - """Reset the callback list before every test.""" + """Save and reset the callback list before every test.""" import comfy.model_management as mm self._original = mm._on_model_destroyed_callbacks.copy() mm._on_model_destroyed_callbacks.clear() def teardown_method(self): - """Restore the original callback list.""" + """Restore the original callback list after every test.""" import comfy.model_management as mm mm._on_model_destroyed_callbacks.clear() mm._on_model_destroyed_callbacks.extend(self._original) def test_register_single_callback(self): + """A single registered callback must be stored and callable.""" import comfy.model_management as mm invocations = [] @@ -269,7 +279,7 @@ class TestMpsFlushConditionality: mm.AGGRESSIVE_OFFLOAD = original def test_flush_requires_mps_device(self): - """The flush condition checks device.type == 'mps'.""" + """The flush must only activate on MPS devices, not CPU or CUDA.""" # Simulate CPU device — flush should not activate cpu_device = torch.device("cpu") assert cpu_device.type != "mps"