From 9a5e6233f44001025ec935593820e4237bcadf9c Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Sun, 15 Feb 2026 10:53:49 -0500 Subject: [PATCH] Fix reshape_mask for 1D spatial dimensions. reshape_mask sets scale_mode="linear" for dims==1 but is missing the input reshape to [N, C, W] that the 2D and 3D branches both perform. Add the missing reshape, matching the existing pattern. --- comfy/utils.py | 1 + .../{samplers_test.py => mask_test.py} | 112 +++++++++++++++++- 2 files changed, 110 insertions(+), 3 deletions(-) rename tests-unit/comfy_test/{samplers_test.py => mask_test.py} (65%) diff --git a/comfy/utils.py b/comfy/utils.py index c1ce540b5..1d9335b66 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1241,6 +1241,7 @@ def reshape_mask(input_mask, output_shape): dims = len(output_shape) - 2 if dims == 1: + input_mask = input_mask.reshape((-1, 1, input_mask.shape[-1])) scale_mode = "linear" if dims == 2: diff --git a/tests-unit/comfy_test/samplers_test.py b/tests-unit/comfy_test/mask_test.py similarity index 65% rename from tests-unit/comfy_test/samplers_test.py rename to tests-unit/comfy_test/mask_test.py index 1325a5c9a..7a6e6da99 100644 --- a/tests-unit/comfy_test/samplers_test.py +++ b/tests-unit/comfy_test/mask_test.py @@ -1,13 +1,16 @@ """ -Tests for resolve_areas_and_cond_masks_multidim mask normalization and resizing. +Tests for mask handling across arbitrary spatial dimensions. -Covers 1D (audio), 2D (image), and 3D (video) spatial dims, including edge cases -around batch size, spurious unsqueeze from upstream nodes, and size mismatches. +Covers resolve_areas_and_cond_masks_multidim (conditioning masks) and +reshape_mask (denoise masks) for 1D (audio), 2D (image), and 3D (video) +spatial dims, including edge cases around batch size, spurious unsqueeze +from upstream nodes, and size mismatches. """ import torch from comfy.samplers import resolve_areas_and_cond_masks_multidim +from comfy.utils import reshape_mask def make_cond(mask): @@ -228,3 +231,106 @@ class TestAreaResolution: resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu") area = conds[0]["area"] assert area == (50, 25) + + +# ============================================================ +# reshape_mask — mask reshaping for F.interpolate +# ============================================================ + +class TestReshapeMask1D: + """Tests for reshape_mask with 1D output (e.g. audio with noise shape [B, C, L]).""" + + def test_4d_input_same_length(self): + """[1, 1, 1, L] input (typical from pipeline) — should reshape and expand channels.""" + mask = torch.ones(1, 1, 1, 100) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_4d_input_resize(self): + """[1, 1, 1, L1] input resized to different length.""" + mask = torch.ones(1, 1, 1, 50) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_3d_input(self): + """[1, 1, L] input — should work directly.""" + mask = torch.ones(1, 1, 100) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_2d_input(self): + """[B, L] input — should reshape to [B, 1, L].""" + mask = torch.ones(1, 50) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_1d_input(self): + """[L] input — should reshape to [1, 1, L].""" + mask = torch.ones(50) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_channel_repeat(self): + """Mask with 1 channel should repeat to match output channels.""" + mask = torch.full((1, 1, 1, 100), 0.5) + result = reshape_mask(mask, torch.Size([1, 32, 100])) + assert result.shape == (1, 32, 100) + torch.testing.assert_close(result, torch.full_like(result, 0.5)) + + def test_batch_repeat(self): + """Single-batch mask should repeat to match output batch size.""" + mask = torch.full((1, 1, 1, 100), 0.7) + result = reshape_mask(mask, torch.Size([4, 64, 100])) + assert result.shape == (4, 64, 100) + + def test_values_preserved_no_resize(self): + """Values should be preserved when no resize is needed.""" + values = torch.tensor([[[0.0, 0.5, 1.0]]]) # [1, 1, 3] + result = reshape_mask(values, torch.Size([1, 1, 3])) + torch.testing.assert_close(result, values) + + def test_interpolation_values(self): + """Linear interpolation should produce sensible intermediate values.""" + mask = torch.tensor([[[[0.0, 1.0]]]]) # [1, 1, 1, 2] + result = reshape_mask(mask, torch.Size([1, 1, 4])) + assert result.shape == (1, 1, 4) + # Should interpolate from 0 to 1 + assert result[0, 0, 0].item() < result[0, 0, -1].item() + + +class TestReshapeMask2D: + """Regression tests for reshape_mask with 2D output (image models).""" + + def test_standard_resize(self): + """[1, 1, H, W] mask resized to different resolution.""" + mask = torch.ones(1, 1, 32, 32) + result = reshape_mask(mask, torch.Size([1, 4, 64, 64])) + assert result.shape == (1, 4, 64, 64) + + def test_same_size(self): + """[1, 1, H, W] mask with matching size — no resize needed.""" + mask = torch.rand(1, 1, 64, 64) + result = reshape_mask(mask, torch.Size([1, 4, 64, 64])) + assert result.shape == (1, 4, 64, 64) + + def test_3d_input(self): + """[B, H, W] input — should reshape to [B, 1, H, W].""" + mask = torch.ones(1, 32, 32) + result = reshape_mask(mask, torch.Size([1, 4, 64, 64])) + assert result.shape == (1, 4, 64, 64) + + +class TestReshapeMask3D: + """Regression tests for reshape_mask with 3D output (video models).""" + + def test_standard_resize(self): + """[1, 1, T, H, W] mask resized to different resolution.""" + mask = torch.ones(1, 1, 8, 32, 32) + result = reshape_mask(mask, torch.Size([1, 4, 8, 64, 64])) + assert result.shape == (1, 4, 8, 64, 64) + + def test_4d_input(self): + """[B, T, H, W] input — should reshape to [1, 1, T, H, W].""" + mask = torch.ones(1, 8, 32, 32) + result = reshape_mask(mask, torch.Size([1, 4, 8, 64, 64])) + assert result.shape == (1, 4, 8, 64, 64)