From abfea891efb1617099f52d7c13487aa46e541cde Mon Sep 17 00:00:00 2001 From: RyanOnTheInside <7623207+ryanontheinside@users.noreply.github.com> Date: Sun, 8 Feb 2026 15:07:20 -0500 Subject: [PATCH] Fix conditioning mask normalization for arbitrary spatial dimensions. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit May also resolve #9784 — the mask normalization fixes a class of dimensionality mismatches that can cause the `y, x = torch.where(mask)` crash in `get_mask_aabb`, though the root cause in that report is unconfirmed. ## Summary `resolve_areas_and_cond_masks_multidim` assumes 2D spatial masks. This breaks for 1D audio models (StableAudio1, ACEAudio15) because upstream code (`ConditioningSetMask`, `set_mask_for_conditioning`) unconditionally unsqueezes masks with `ndim < 3`, corrupting valid `[B, L]` masks into `[1, B, L]` before they reach the sampler. This PR: - Normalizes masks to `[batch, *spatial_dims]` using `dims` as the source of truth - Adds a 1D resize path via `F.interpolate(mode='linear')` - Guards `set_area_to_bounds` with `len(dims) == 2` to prevent crashes on non-2D masks (the existing `get_mask_aabb` and `H, W, Y, X` unpacking are 2D-only) The root cause is the hardcoded `if len(mask.shape) < 3` in `nodes.py:242` and `hooks.py:725`. Fixing it there would require threading latent dimensionality into the conditioning nodes — a much larger change. Normalizing in `resolve_areas_and_cond_masks_multidim` where `dims` is already available is the minimal fix. Fully backwards compatible for existing 2D image and 3D video workflows. 
## Test plan - [x] 26 unit tests covering 1D/2D/3D mask normalization, resize, and `set_area_to_bounds` guard (`tests-unit/comfy_test/samplers_test.py`) - [x] 2D image regression with hook masking: [lorahookmasking.json](https://github.com/Kosinkadink/ComfyUI/blob/workflows/lorahookmasking.json) - [x] 2D image with `set_area_to_bounds` ("mask bounds" mode) — no crash, correct area computation - [x] 1D audio with conditioning mask: [acestep-1.5-prompt-lora-blending.json](https://github.com/ryanontheinside/ComfyUI_RyanOnTheInside/blob/main/examples/ace1.5/acestep-1.5-prompt-lora-blending.json) (requires custom nodes that patch this function pending upstream) --- comfy/samplers.py | 14 +- tests-unit/comfy_test/samplers_test.py | 230 +++++++++++++++++++++++++ 2 files changed, 241 insertions(+), 3 deletions(-) create mode 100644 tests-unit/comfy_test/samplers_test.py diff --git a/comfy/samplers.py b/comfy/samplers.py index 8b9782956..8b530b558 100755 --- a/comfy/samplers.py +++ b/comfy/samplers.py @@ -543,15 +543,23 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device): mask = c['mask'] mask = mask.to(device=device) modified = c.copy() - if len(mask.shape) == len(dims): + # Normalize mask to [batch, *spatial_dims] + target_ndim = len(dims) + 1 + while mask.ndim > target_ndim and mask.shape[0] == 1: + mask = mask.squeeze(0) + while mask.ndim < target_ndim: mask = mask.unsqueeze(0) if mask.shape[1:] != dims: - if mask.ndim < 4: + if len(dims) == 1: + mask = torch.nn.functional.interpolate( + mask.unsqueeze(1), size=dims[0], + mode='linear', align_corners=False).squeeze(1) + elif mask.ndim < 4: mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1) else: mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none') - if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2 + if modified.get("set_area_to_bounds", False) and len(dims) == 2: bounds = 
torch.max(torch.abs(mask),dim=0).values.unsqueeze(0) boxes, is_empty = get_mask_aabb(bounds) if is_empty[0]: diff --git a/tests-unit/comfy_test/samplers_test.py b/tests-unit/comfy_test/samplers_test.py new file mode 100644 index 000000000..1325a5c9a --- /dev/null +++ b/tests-unit/comfy_test/samplers_test.py @@ -0,0 +1,230 @@ +""" +Tests for resolve_areas_and_cond_masks_multidim mask normalization and resizing. + +Covers 1D (audio), 2D (image), and 3D (video) spatial dims, including edge cases +around batch size, spurious unsqueeze from upstream nodes, and size mismatches. +""" + +import torch + +from comfy.samplers import resolve_areas_and_cond_masks_multidim + + +def make_cond(mask): + """Create a minimal conditioning dict with a mask.""" + return {"mask": mask, "model_conds": {}} + + +def run_resolve(mask, dims, device="cpu"): + """Run resolve on a single condition and return the resolved mask.""" + conds = [make_cond(mask)] + resolve_areas_and_cond_masks_multidim(conds, dims, device) + return conds[0]["mask"] + + +# ============================================================ +# 1D spatial dims (audio models like AceStep v1.5) +# dims = (length,), expected mask output: [batch, length] +# ============================================================ + +class Test1DSpatial: + """Tests for 1D spatial models (e.g. 
audio with noise shape [B, C, L]).""" + + def test_correct_shape_same_length(self): + """[B, L] mask with matching length — should pass through unchanged.""" + mask = torch.ones(2, 100) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (2, 100) + + def test_correct_shape_resize(self): + """[B, L] mask with different length — should resize via linear interp.""" + mask = torch.ones(1, 50) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (1, 100) + + def test_bare_spatial_mask(self): + """[L] mask (no batch) — should get batch dim added.""" + mask = torch.ones(50) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (1, 100) + + def test_spurious_unsqueeze_from_hooks(self): + """[1, B, L] mask (from set_mask_for_conditioning unsqueezing a [B, L] mask) + — should squeeze back to [B, L].""" + # Simulates: mask is [B, L], hooks.py does unsqueeze(0) -> [1, B, L] + mask = torch.ones(1, 2, 100) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (2, 100) + + def test_spurious_unsqueeze_batch1(self): + """[1, 1, L] mask (batch=1, hooks added extra dim) — should become [1, L].""" + mask = torch.ones(1, 1, 50) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (1, 100) + + def test_batch_gt1_same_length(self): + """[B, L] mask with batch=4 and matching length — no changes needed.""" + mask = torch.rand(4, 100) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (4, 100) + torch.testing.assert_close(result, mask) + + def test_batch_gt1_resize(self): + """[B, L] mask with batch=4 and different length — should resize each batch.""" + mask = torch.rand(4, 50) + result = run_resolve(mask, dims=(100,)) + assert result.shape == (4, 100) + + def test_values_preserved_no_resize(self): + """Mask values should be preserved when no resize is needed.""" + mask = torch.tensor([[0.0, 0.5, 1.0]]) + result = run_resolve(mask, dims=(3,)) + torch.testing.assert_close(result, mask) + + def 
test_linear_interpolation_values(self): + """Check that linear interpolation produces sensible values.""" + mask = torch.tensor([[0.0, 1.0]]) # [1, 2] + result = run_resolve(mask, dims=(5,)) + assert result.shape == (1, 5) + # Should interpolate from 0 to 1 + assert result[0, 0].item() < result[0, -1].item() + + def test_set_area_to_bounds_skipped_for_1d(self): + """set_area_to_bounds should be skipped for 1D (no crash).""" + mask = torch.zeros(1, 100) + mask[0, 10:50] = 1.0 + conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}] + resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu") + assert "area" not in conds[0] + + +# ============================================================ +# 2D spatial dims (image models) — regression tests +# dims = (H, W), expected mask output: [batch, H, W] +# ============================================================ + +class Test2DSpatial: + """Regression tests for standard 2D image models.""" + + def test_correct_shape_same_size(self): + """[B, H, W] mask matching dims — pass through.""" + mask = torch.ones(1, 64, 64) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (1, 64, 64) + + def test_bare_spatial_mask(self): + """[H, W] mask — should get batch dim added.""" + mask = torch.ones(64, 64) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (1, 64, 64) + + def test_resize_different_resolution(self): + """[B, H1, W1] mask with different size than dims — should bilinear resize.""" + mask = torch.ones(1, 32, 32) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (1, 64, 64) + + def test_4d_mask(self): + """[1, 1, H, W] mask (4D) — extra leading singleton dim is squeezed, then the [1, H, W] mask is bilinear-resized.""" + mask = torch.ones(1, 1, 32, 32) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (1, 64, 64) + + def test_batch_gt1(self): + """[B, H, W] mask with batch > 1.""" + mask = torch.rand(4, 64, 64) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == 
(4, 64, 64) + + def test_batch_gt1_resize(self): + """[B, H, W] mask with batch > 1 and different resolution.""" + mask = torch.rand(4, 32, 32) + result = run_resolve(mask, dims=(64, 64)) + assert result.shape == (4, 64, 64) + + def test_set_area_to_bounds(self): + """set_area_to_bounds should work for 2D masks.""" + mask = torch.zeros(1, 64, 64) + mask[0, 10:20, 10:30] = 1.0 + conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}] + resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu") + assert "area" in conds[0] + + def test_non_square_resize(self): + """[B, H1, W1] mask resized to non-square dims.""" + mask = torch.ones(1, 16, 32) + result = run_resolve(mask, dims=(64, 128)) + assert result.shape == (1, 64, 128) + + +# ============================================================ +# 3D spatial dims (video models) +# dims = (T, H, W), expected mask output: [batch, T, H, W] +# ============================================================ + +class Test3DSpatial: + """Tests for 3D spatial models (e.g. 
video with noise shape [B, C, T, H, W]).""" + + def test_correct_shape_same_size(self): + """[B, T, H, W] mask matching dims — pass through.""" + mask = torch.ones(1, 8, 64, 64) + result = run_resolve(mask, dims=(8, 64, 64)) + assert result.shape == (1, 8, 64, 64) + + def test_bare_spatial_mask(self): + """[T, H, W] mask — should get batch dim added.""" + mask = torch.ones(8, 64, 64) + result = run_resolve(mask, dims=(8, 64, 64)) + assert result.shape == (1, 8, 64, 64) + + def test_resize_hw(self): + """[B, T, H1, W1] mask with different H, W — should resize last 2 dims.""" + mask = torch.ones(1, 8, 32, 32) + result = run_resolve(mask, dims=(8, 64, 64)) + assert result.shape == (1, 8, 64, 64) + + def test_set_area_to_bounds_skipped_for_3d(self): + """set_area_to_bounds should be skipped for 3D (no crash).""" + mask = torch.zeros(1, 8, 64, 64) + mask[0, :, 10:20, 10:30] = 1.0 + conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}] + resolve_areas_and_cond_masks_multidim(conds, (8, 64, 64), "cpu") + assert "area" not in conds[0] + + +class TestNoMask: + """Conditions without masks should pass through untouched.""" + + def test_no_mask_key(self): + """Condition with no mask key — untouched.""" + conds = [{"model_conds": {}}] + resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu") + assert "mask" not in conds[0] + + def test_empty_conditions(self): + """Empty conditions list — no crash.""" + conds = [] + resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu") + assert len(conds) == 0 + + +# ============================================================ +# Area resolution (percentage-based) +# ============================================================ + +class TestAreaResolution: + """Test that percentage-based area resolution works for different dims.""" + + def test_percentage_area_2d(self): + """Percentage area for 2D should resolve to pixel coords.""" + conds = [{"area": ("percentage", 0.5, 0.5, 0.25, 0.25), "model_conds": {}}] + 
resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu") + area = conds[0]["area"] + assert area == (32, 32, 16, 16) + + def test_percentage_area_1d(self): + """Percentage area for 1D should resolve to frame coords.""" + conds = [{"area": ("percentage", 0.5, 0.25), "model_conds": {}}] + resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu") + area = conds[0]["area"] + assert area == (50, 25)