From abfea891efb1617099f52d7c13487aa46e541cde Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Sun, 8 Feb 2026 15:07:20 -0500 Subject: [PATCH 1/2] Fix conditioning mask normalization for arbitrary spatial dimensions. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit May also resolve #9784 — the mask normalization fixes a class of dimensionality mismatches that can cause the `y, x = torch.where(mask)` crash in `get_mask_aabb`, though the root cause in that report is unconfirmed. ## Summary `resolve_areas_and_cond_masks_multidim` assumes 2D spatial masks. This breaks for 1D audio models (StableAudio1, ACEAudio15) because upstream code (`ConditioningSetMask`, `set_mask_for_conditioning`) unconditionally unsqueezes masks with `ndim < 3`, corrupting valid `[B, L]` masks into `[1, B, L]` before they reach the sampler. This PR: - Normalizes masks to `[batch, *spatial_dims]` using `dims` as the source of truth - Adds a 1D resize path via `F.interpolate(mode='linear')` - Guards `set_area_to_bounds` with `len(dims) == 2` to prevent crashes on non-2D masks (the existing `get_mask_aabb` and `H, W, Y, X` unpacking are 2D-only) The root cause is the hardcoded `if len(mask.shape) < 3` in `nodes.py:242` and `hooks.py:725`. Fixing it there would require threading latent dimensionality into the conditioning nodes — a much larger change. Normalizing in `resolve_areas_and_cond_masks_multidim` where `dims` is already available is the minimal fix. Fully backwards compatible for existing 2D image and 3D video workflows. ## Test plan - [x] 26 unit tests covering 1D/2D/3D mask normalization, resize, and `set_area_to_bounds` guard (`tests-unit/comfy_test/samplers_test.py`) - [x] 2D image regression with hook masking: [lorahookmasking.json](https://github.com/Kosinkadink/ComfyUI/blob/workflows/lorahookmasking.json) - [x] 2D image with `set_area_to_bounds` ("mask bounds" mode) — no crash, correct area computation - [x] 1D audio with conditioning mask: [acestep-1.5-prompt-lora-blending.json](https://github.com/ryanontheinside/ComfyUI_RyanOnTheInside/blob/main/examples/ace1.5/acestep-1.5-prompt-lora-blending.json) (requires custom nodes that patch this function pending upstream) --- comfy/samplers.py | 14 +- tests-unit/comfy_test/samplers_test.py | 230 +++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 3 deletions(-) create mode 100644 tests-unit/comfy_test/samplers_test.py diff --git a/comfy/samplers.py b/comfy/samplers.py index 8b9782956..8b530b558 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -543,15 +543,23 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device): mask = c['mask'] mask = mask.to(device=device) modified = c.copy() - if len(mask.shape) == len(dims): + # Normalize mask to [batch, *spatial_dims] + target_ndim = len(dims) + 1 + while mask.ndim > target_ndim and mask.shape[0] == 1: + mask = mask.squeeze(0) + while mask.ndim < target_ndim: mask = mask.unsqueeze(0) if mask.shape[1:] != dims: - if mask.ndim < 4: + if len(dims) == 1: + mask = torch.nn.functional.interpolate( + mask.unsqueeze(1), size=dims[0], + mode='linear', align_corners=False).squeeze(1) + elif mask.ndim < 4: mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1) else: mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none') - if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2 + if modified.get("set_area_to_bounds", False) and len(dims) == 2: bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) boxes, is_empty = get_mask_aabb(bounds) if is_empty[0]: diff --git a/tests-unit/comfy_test/samplers_test.py b/tests-unit/comfy_test/samplers_test.py new file mode 100644 index 000000000..1325a5c9a --- /dev/null +++ b/tests-unit/comfy_test/samplers_test.py @@ -0,0 +1,230 @@ +""" +Tests for resolve_areas_and_cond_masks_multidim mask normalization and resizing. + +Covers 1D (audio), 2D (image), and 3D (video) spatial dims, including edge cases +around batch size, spurious unsqueeze from upstream nodes, and size mismatches. +""" + +import torch + +from comfy.samplers import resolve_areas_and_cond_masks_multidim + + +def make_cond(mask): + """Create a minimal conditioning dict with a mask.""" + return {"mask": mask, "model_conds": {}} + + +def run_resolve(mask, dims, device="cpu"): + """Run resolve on a single condition and return the resolved mask.""" + conds = [make_cond(mask)] + resolve_areas_and_cond_masks_multidim(conds, dims, device) + return conds[0]["mask"] + + +# ============================================================ +# 1D spatial dims (audio models like AceStep v1.5) +# dims = (length,), expected mask output: [batch, length] +# ============================================================ + +class Test1DSpatial: + """Tests for 1D spatial models (e.g. audio with noise shape [B, C, L]).""" + + def test_correct_shape_same_length(self): + """[B, L] mask with matching length — should pass through unchanged.""" + mask = torch.ones(2, 100) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (2, 100) + + def test_correct_shape_resize(self): + """[B, L] mask with different length — should resize via linear interp.""" + mask = torch.ones(1, 50) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (1, 100) + + def test_bare_spatial_mask(self): + """[L] mask (no batch) — should get batch dim added.""" + mask = torch.ones(50) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (1, 100) + + def test_spurious_unsqueeze_from_hooks(self): + """[1, B, L] mask (from set_mask_for_conditioning unsqueezing a [B, L] mask) + — should squeeze back to [B, L].""" + # Simulates: mask is [B, L], hooks.py does unsqueeze(0) -> [1, B, L] + mask = torch.ones(1, 2, 100) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (2, 100) + + def test_spurious_unsqueeze_batch1(self): + """[1, 1, L] mask (batch=1, hooks added extra dim) — should become [1, L].""" + mask = torch.ones(1, 1, 50) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (1, 100) + + def test_batch_gt1_same_length(self): + """[B, L] mask with batch=4 and matching length — no changes needed.""" + mask = torch.rand(4, 100) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (4, 100) + torch.testing.assert_close(result, mask) + + def test_batch_gt1_resize(self): + """[B, L] mask with batch=4 and different length — should resize each batch.""" + mask = torch.rand(4, 50) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (4, 100) + + def test_values_preserved_no_resize(self): + """Mask values should be preserved when no resize is needed.""" + mask = torch.tensor([[0.0, 0.5, 1.0]]) + result = run_resolve(mask, dims=(3,)) + torch.testing.assert_close(result, mask) + + def test_linear_interpolation_values(self): + """Check that linear interpolation produces sensible values.""" + mask = torch.tensor([[0.0, 1.0]]) # [1, 2] + result = run_resolve(mask, dims=(5,)) + assert result.shape == (1, 5) + # Should interpolate from 0 to 1 + assert result[0, 0].item() < result[0, -1].item() + + def test_set_area_to_bounds_skipped_for_1d(self): + """set_area_to_bounds should be skipped for 1D (no crash).""" + mask = torch.zeros(1, 100) + mask[0, 10:50] = 1.0 + conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}] + resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu") + assert "area" not in conds[0] + + +# ============================================================ +# 2D spatial dims (image models) — regression tests +# dims = (H, W), expected mask output: [batch, H, W] +# ============================================================ + +class Test2DSpatial: + """Regression tests for standard 2D image models.""" + + def test_correct_shape_same_size(self): + """[B, H, W] mask matching dims — pass through.""" + mask = torch.ones(1, 64, 64) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (1, 64, 64) + + def test_bare_spatial_mask(self): + """[H, W] mask — should get batch dim added.""" + mask = torch.ones(64, 64) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (1, 64, 64) + + def test_resize_different_resolution(self): + """[B, H1, W1] mask with different size than dims — should bilinear resize.""" + mask = torch.ones(1, 32, 32) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (1, 64, 64) + + def test_4d_mask(self): + """[B, C, H, W] mask (4D) — should resize via common_upscale 4D path.""" + mask = torch.ones(1, 1, 32, 32) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (1, 64, 64) + + def test_batch_gt1(self): + """[B, H, W] mask with batch > 1.""" + mask = torch.rand(4, 64, 64) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (4, 64, 64) + + def test_batch_gt1_resize(self): + """[B, H, W] mask with batch > 1 and different resolution.""" + mask = torch.rand(4, 32, 32) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (4, 64, 64) + + def test_set_area_to_bounds(self): + """set_area_to_bounds should work for 2D masks.""" + mask = torch.zeros(1, 64, 64) + mask[0, 10:20, 10:30] = 1.0 + conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}] + resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu") + assert "area" in conds[0] + + def test_non_square_resize(self): + """[B, H1, W1] mask resized to non-square dims.""" + mask = torch.ones(1, 16, 32) + result = run_resolve(mask, dims=(64, 128)) + assert result.shape == (1, 64, 128) + + +# ============================================================ +# 3D spatial dims (video models) +# dims = (T, H, W), expected mask output: [batch, T, H, W] +# ============================================================ + +class Test3DSpatial: + """Tests for 3D spatial models (e.g. video with noise shape [B, C, T, H, W]).""" + + def test_correct_shape_same_size(self): + """[B, T, H, W] mask matching dims — pass through.""" + mask = torch.ones(1, 8, 64, 64) + result = run_resolve(mask, dims=(8, 64, 64)) + assert result.shape == (1, 8, 64, 64) + + def test_bare_spatial_mask(self): + """[T, H, W] mask — should get batch dim added.""" + mask = torch.ones(8, 64, 64) + result = run_resolve(mask, dims=(8, 64, 64)) + assert result.shape == (1, 8, 64, 64) + + def test_resize_hw(self): + """[B, T, H1, W1] mask with different H, W — should resize last 2 dims.""" + mask = torch.ones(1, 8, 32, 32) + result = run_resolve(mask, dims=(8, 64, 64)) + assert result.shape == (1, 8, 64, 64) + + def test_set_area_to_bounds_skipped_for_3d(self): + """set_area_to_bounds should be skipped for 3D (no crash).""" + mask = torch.zeros(1, 8, 64, 64) + mask[0, :, 10:20, 10:30] = 1.0 + conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}] + resolve_areas_and_cond_masks_multidim(conds, (8, 64, 64), "cpu") + assert "area" not in conds[0] + + +class TestNoMask: + """Conditions without masks should pass through untouched.""" + + def test_no_mask_key(self): + """Condition with no mask key — untouched.""" + conds = [{"model_conds": {}}] + resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu") + assert "mask" not in conds[0] + + def test_empty_conditions(self): + """Empty conditions list — no crash.""" + conds = [] + resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu") + assert len(conds) == 0 + + +# ============================================================ +# Area resolution (percentage-based) +# ============================================================ + +class TestAreaResolution: + """Test that percentage-based area resolution works for different dims.""" + + def test_percentage_area_2d(self): + """Percentage area for 2D should resolve to pixel coords.""" + conds = [{"area": ("percentage", 0.5, 0.5, 0.25, 0.25), "model_conds": {}}] + resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu") + area = conds[0]["area"] + assert area == (32, 32, 16, 16) + + def test_percentage_area_1d(self): + """Percentage area for 1D should resolve to frame coords.""" + conds = [{"area": ("percentage", 0.5, 0.25), "model_conds": {}}] + resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu") + area = conds[0]["area"] + assert area == (50, 25) From 9a5e6233f44001025ec935593820e4237bcadf9c Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Sun, 15 Feb 2026 10:53:49 -0500 Subject: [PATCH 2/2] Fix reshape_mask for 1D spatial dimensions. reshape_mask sets scale_mode="linear" for dims==1 but is missing the input reshape to [N, C, W] that the 2D and 3D branches both perform. Add the missing reshape, matching the existing pattern. --- comfy/utils.py | 1 + .../{samplers_test.py => mask_test.py} | 112 +++++++++++++++++- 2 files changed, 110 insertions(+), 3 deletions(-) rename tests-unit/comfy_test/{samplers_test.py => mask_test.py} (65%) diff --git a/comfy/utils.py b/comfy/utils.py index c1ce540b5..1d9335b66 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -1241,6 +1241,7 @@ def reshape_mask(input_mask, output_shape): dims = len(output_shape) - 2 if dims == 1: + input_mask = input_mask.reshape((-1, 1, input_mask.shape[-1])) scale_mode = "linear" if dims == 2: diff --git a/tests-unit/comfy_test/samplers_test.py b/tests-unit/comfy_test/mask_test.py similarity index 65% rename from tests-unit/comfy_test/samplers_test.py rename to tests-unit/comfy_test/mask_test.py index 1325a5c9a..7a6e6da99 100644 --- a/tests-unit/comfy_test/samplers_test.py +++ b/tests-unit/comfy_test/mask_test.py @@ -1,13 +1,16 @@ """ -Tests for resolve_areas_and_cond_masks_multidim mask normalization and resizing. +Tests for mask handling across arbitrary spatial dimensions. -Covers 1D (audio), 2D (image), and 3D (video) spatial dims, including edge cases -around batch size, spurious unsqueeze from upstream nodes, and size mismatches. +Covers resolve_areas_and_cond_masks_multidim (conditioning masks) and +reshape_mask (denoise masks) for 1D (audio), 2D (image), and 3D (video) +spatial dims, including edge cases around batch size, spurious unsqueeze +from upstream nodes, and size mismatches. """ import torch from comfy.samplers import resolve_areas_and_cond_masks_multidim +from comfy.utils import reshape_mask def make_cond(mask): @@ -228,3 +231,106 @@ class TestAreaResolution: resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu") area = conds[0]["area"] assert area == (50, 25) + + +# ============================================================ +# reshape_mask — mask reshaping for F.interpolate +# ============================================================ + +class TestReshapeMask1D: + """Tests for reshape_mask with 1D output (e.g. audio with noise shape [B, C, L]).""" + + def test_4d_input_same_length(self): + """[1, 1, 1, L] input (typical from pipeline) — should reshape and expand channels.""" + mask = torch.ones(1, 1, 1, 100) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_4d_input_resize(self): + """[1, 1, 1, L1] input resized to different length.""" + mask = torch.ones(1, 1, 1, 50) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_3d_input(self): + """[1, 1, L] input — should work directly.""" + mask = torch.ones(1, 1, 100) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_2d_input(self): + """[B, L] input — should reshape to [B, 1, L].""" + mask = torch.ones(1, 50) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_1d_input(self): + """[L] input — should reshape to [1, 1, L].""" + mask = torch.ones(50) + result = reshape_mask(mask, torch.Size([1, 64, 100])) + assert result.shape == (1, 64, 100) + + def test_channel_repeat(self): + """Mask with 1 channel should repeat to match output channels.""" + mask = torch.full((1, 1, 1, 100), 0.5) + result = reshape_mask(mask, torch.Size([1, 32, 100])) + assert result.shape == (1, 32, 100) + torch.testing.assert_close(result, torch.full_like(result, 0.5)) + + def test_batch_repeat(self): + """Single-batch mask should repeat to match output batch size.""" + mask = torch.full((1, 1, 1, 100), 0.7) + result = reshape_mask(mask, torch.Size([4, 64, 100])) + assert result.shape == (4, 64, 100) + + def test_values_preserved_no_resize(self): + """Values should be preserved when no resize is needed.""" + values = torch.tensor([[[0.0, 0.5, 1.0]]]) # [1, 1, 3] + result = reshape_mask(values, torch.Size([1, 1, 3])) + torch.testing.assert_close(result, values) + + def test_interpolation_values(self): + """Linear interpolation should produce sensible intermediate values.""" + mask = torch.tensor([[[[0.0, 1.0]]]]) # [1, 1, 1, 2] + result = reshape_mask(mask, torch.Size([1, 1, 4])) + assert result.shape == (1, 1, 4) + # Should interpolate from 0 to 1 + assert result[0, 0, 0].item() < result[0, 0, -1].item() + + +class TestReshapeMask2D: + """Regression tests for reshape_mask with 2D output (image models).""" + + def test_standard_resize(self): + """[1, 1, H, W] mask resized to different resolution.""" + mask = torch.ones(1, 1, 32, 32) + result = reshape_mask(mask, torch.Size([1, 4, 64, 64])) + assert result.shape == (1, 4, 64, 64) + + def test_same_size(self): + """[1, 1, H, W] mask with matching size — no resize needed.""" + mask = torch.rand(1, 1, 64, 64) + result = reshape_mask(mask, torch.Size([1, 4, 64, 64])) + assert result.shape == (1, 4, 64, 64) + + def test_3d_input(self): + """[B, H, W] input — should reshape to [B, 1, H, W].""" + mask = torch.ones(1, 32, 32) + result = reshape_mask(mask, torch.Size([1, 4, 64, 64])) + assert result.shape == (1, 4, 64, 64) + + +class TestReshapeMask3D: + """Regression tests for reshape_mask with 3D output (video models).""" + + def test_standard_resize(self): + """[1, 1, T, H, W] mask resized to different resolution.""" + mask = torch.ones(1, 1, 8, 32, 32) + result = reshape_mask(mask, torch.Size([1, 4, 8, 64, 64])) + assert result.shape == (1, 4, 8, 64, 64) + + def test_4d_input(self): + """[B, T, H, W] input — should reshape to [1, 1, T, H, W].""" + mask = torch.ones(1, 8, 32, 32) + result = reshape_mask(mask, torch.Size([1, 4, 8, 64, 64])) + assert result.shape == (1, 4, 8, 64, 64)