ComfyUI/tests-unit/comfy_test/samplers_test.py
RyanOnTheInside abfea891ef Fix conditioning mask normalization for arbitrary spatial dimensions.
May also resolve #9784 — the mask normalization fixes a class of dimensionality mismatches that can cause the `y, x = torch.where(mask)` crash in `get_mask_aabb`, though the root cause in that report is unconfirmed.

## Summary

`resolve_areas_and_cond_masks_multidim` assumes 2D spatial masks. This breaks for 1D audio models (StableAudio1, ACEAudio15) because upstream code (`ConditioningSetMask`, `set_mask_for_conditioning`) unconditionally unsqueezes masks with `ndim < 3`, corrupting valid `[B, L]` masks into `[1, B, L]` before they reach the sampler.

This PR:
- Normalizes masks to `[batch, *spatial_dims]` using `dims` as the source of truth
- Adds a 1D resize path via `F.interpolate(mode='linear')`
- Guards `set_area_to_bounds` with `len(dims) == 2` to prevent crashes on non-2D masks (the existing `get_mask_aabb` and `H, W, Y, X` unpacking are 2D-only)

The root cause is the hardcoded `if len(mask.shape) < 3` in `nodes.py:242` and `hooks.py:725`. Fixing it there would require threading latent dimensionality into the conditioning nodes — a much larger change. Normalizing in `resolve_areas_and_cond_masks_multidim` where `dims` is already available is the minimal fix.

Fully backward compatible with existing 2D image and 3D video workflows.

## Test plan

- [x] 26 unit tests covering 1D/2D/3D mask normalization, resize, and `set_area_to_bounds` guard (`tests-unit/comfy_test/samplers_test.py`)
- [x] 2D image regression with hook masking: [lorahookmasking.json](https://github.com/Kosinkadink/ComfyUI/blob/workflows/lorahookmasking.json)
- [x] 2D image with `set_area_to_bounds` ("mask bounds" mode) — no crash, correct area computation
- [x] 1D audio with conditioning mask: [acestep-1.5-prompt-lora-blending.json](https://github.com/ryanontheinside/ComfyUI_RyanOnTheInside/blob/main/examples/ace1.5/acestep-1.5-prompt-lora-blending.json) (requires custom nodes that patch this function until the fix lands upstream)
2026-02-15 09:45:14 -05:00

231 lines
9.0 KiB
Python

"""
Tests for resolve_areas_and_cond_masks_multidim mask normalization and resizing.
Covers 1D (audio), 2D (image), and 3D (video) spatial dims, including edge cases
around batch size, spurious unsqueeze from upstream nodes, and size mismatches.
"""
import torch
from comfy.samplers import resolve_areas_and_cond_masks_multidim
def make_cond(mask):
    """Build the smallest conditioning dict used by these tests.

    Only ``mask`` and an empty ``model_conds`` mapping are populated.
    """
    cond = {"mask": mask}
    cond["model_conds"] = {}
    return cond
def run_resolve(mask, dims, device="cpu"):
    """Resolve one masked condition and hand back its processed mask.

    The resolver reports its result through the list it is given, so the
    mask is read back from the list's first entry after the call.
    """
    conditions = [make_cond(mask)]
    resolve_areas_and_cond_masks_multidim(conditions, dims, device)
    first = conditions[0]
    return first["mask"]
# ============================================================
# 1D spatial dims (audio models like AceStep v1.5)
# dims = (length,), expected mask output: [batch, length]
# ============================================================
class Test1DSpatial:
    """Tests for 1D spatial models (e.g. audio with noise shape [B, C, L]).

    Here ``dims`` is ``(length,)`` and the resolved mask should be
    ``[batch, length]``.
    """

    def test_correct_shape_same_length(self):
        """[B, L] mask with matching length — should pass through unchanged."""
        mask = torch.rand(2, 100)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (2, 100)
        # "Unchanged" means values too, not just shape.
        torch.testing.assert_close(result, mask)

    def test_correct_shape_resize(self):
        """[B, L] mask with different length — should resize via linear interp."""
        mask = torch.ones(1, 50)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (1, 100)
        # Interpolating a constant signal must leave it constant.
        torch.testing.assert_close(result, torch.ones(1, 100))

    def test_bare_spatial_mask(self):
        """[L] mask (no batch) — should get batch dim added."""
        mask = torch.ones(50)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (1, 100)
        torch.testing.assert_close(result, torch.ones(1, 100))

    def test_spurious_unsqueeze_from_hooks(self):
        """[1, B, L] mask (from set_mask_for_conditioning unsqueezing a [B, L] mask)
        — should squeeze back to [B, L]."""
        # Simulates: mask is [B, L], hooks.py does unsqueeze(0) -> [1, B, L]
        original = torch.rand(2, 100)
        mask = original.unsqueeze(0)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (2, 100)
        # Length already matches dims, so the squeeze must not alter values;
        # random data (not ones) makes a wrong squeeze/resample detectable.
        torch.testing.assert_close(result, original)

    def test_spurious_unsqueeze_batch1(self):
        """[1, 1, L] mask (batch=1, hooks added extra dim) — should become [1, L]."""
        mask = torch.ones(1, 1, 50)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (1, 100)
        torch.testing.assert_close(result, torch.ones(1, 100))

    def test_batch_gt1_same_length(self):
        """[B, L] mask with batch=4 and matching length — no changes needed."""
        mask = torch.rand(4, 100)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (4, 100)
        torch.testing.assert_close(result, mask)

    def test_batch_gt1_resize(self):
        """[B, L] mask with batch=4 and different length — should resize each batch."""
        mask = torch.rand(4, 50)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (4, 100)
        # Linear interpolation blends samples within a row, so each resized
        # row must stay inside that row's original value range (this would
        # fail if batch entries were mixed together).
        lo = mask.amin(dim=1, keepdim=True)
        hi = mask.amax(dim=1, keepdim=True)
        assert torch.all(result >= lo - 1e-6)
        assert torch.all(result <= hi + 1e-6)

    def test_values_preserved_no_resize(self):
        """Mask values should be preserved when no resize is needed."""
        mask = torch.tensor([[0.0, 0.5, 1.0]])
        result = run_resolve(mask, dims=(3,))
        torch.testing.assert_close(result, mask)

    def test_linear_interpolation_values(self):
        """Check that linear interpolation produces sensible values."""
        mask = torch.tensor([[0.0, 1.0]])  # [1, 2]
        result = run_resolve(mask, dims=(5,))
        assert result.shape == (1, 5)
        row = result[0]
        # Should interpolate from 0 to 1 ...
        assert row[0].item() < row[-1].item()
        # ... monotonically and without overshoot ...
        assert torch.all(row[1:] >= row[:-1])
        assert torch.all((row >= 0.0) & (row <= 1.0))
        # ... and symmetrically: the centre sample of an odd-length output is
        # 0.5 regardless of the align_corners setting.
        assert abs(row[2].item() - 0.5) < 1e-6

    def test_set_area_to_bounds_skipped_for_1d(self):
        """set_area_to_bounds should be skipped for 1D (no crash)."""
        mask = torch.zeros(1, 100)
        mask[0, 10:50] = 1.0
        conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}]
        resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu")
        assert "area" not in conds[0]
        # Skipping the bounds step must still leave a correctly shaped mask.
        assert conds[0]["mask"].shape == (1, 100)
# ============================================================
# 2D spatial dims (image models) — regression tests
# dims = (H, W), expected mask output: [batch, H, W]
# ============================================================
class Test2DSpatial:
    """Regression tests for standard 2D image models.

    Here ``dims`` is ``(H, W)`` and the resolved mask should be
    ``[batch, H, W]``.
    """

    def test_correct_shape_same_size(self):
        """[B, H, W] mask matching dims — pass through."""
        mask = torch.rand(1, 64, 64)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (1, 64, 64)
        torch.testing.assert_close(result, mask)

    def test_bare_spatial_mask(self):
        """[H, W] mask — should get batch dim added."""
        mask = torch.rand(64, 64)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (1, 64, 64)
        # Only a batch dim is added; no resampling should happen.
        torch.testing.assert_close(result, mask.unsqueeze(0))

    def test_resize_different_resolution(self):
        """[B, H1, W1] mask with different size than dims — should bilinear resize."""
        mask = torch.ones(1, 32, 32)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (1, 64, 64)
        # Bilinear resize of a constant mask must stay constant.
        torch.testing.assert_close(result, torch.ones(1, 64, 64))

    def test_4d_mask(self):
        """[B, C, H, W] mask (4D) — should resize via common_upscale 4D path."""
        mask = torch.ones(1, 1, 32, 32)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (1, 64, 64)
        torch.testing.assert_close(result, torch.ones(1, 64, 64))

    def test_batch_gt1(self):
        """[B, H, W] mask with batch > 1."""
        mask = torch.rand(4, 64, 64)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (4, 64, 64)
        torch.testing.assert_close(result, mask)

    def test_batch_gt1_resize(self):
        """[B, H, W] mask with batch > 1 and different resolution."""
        mask = torch.rand(4, 32, 32)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (4, 64, 64)
        # Bilinear upsampling blends pixels of the same image only, so each
        # batch entry stays within its own original value range.
        lo = mask.amin(dim=(1, 2), keepdim=True)
        hi = mask.amax(dim=(1, 2), keepdim=True)
        assert torch.all(result >= lo - 1e-6)
        assert torch.all(result <= hi + 1e-6)

    def test_set_area_to_bounds(self):
        """set_area_to_bounds should work for 2D masks."""
        mask = torch.zeros(1, 64, 64)
        mask[0, 10:20, 10:30] = 1.0
        conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}]
        resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu")
        assert "area" in conds[0]
        # Rows 10..19 and columns 10..29 are set, giving an (H, W, Y, X)
        # bounding area of height 10, width 20, anchored at (10, 10).
        assert tuple(conds[0]["area"]) == (10, 20, 10, 10)

    def test_non_square_resize(self):
        """[B, H1, W1] mask resized to non-square dims."""
        mask = torch.ones(1, 16, 32)
        result = run_resolve(mask, dims=(64, 128))
        assert result.shape == (1, 64, 128)
        torch.testing.assert_close(result, torch.ones(1, 64, 128))
# ============================================================
# 3D spatial dims (video models)
# dims = (T, H, W), expected mask output: [batch, T, H, W]
# ============================================================
class Test3DSpatial:
    """Tests for 3D spatial models (e.g. video with noise shape [B, C, T, H, W]).

    Here ``dims`` is ``(T, H, W)`` and the resolved mask should be
    ``[batch, T, H, W]``.
    """

    def test_correct_shape_same_size(self):
        """[B, T, H, W] mask matching dims — pass through."""
        mask = torch.rand(1, 8, 64, 64)
        result = run_resolve(mask, dims=(8, 64, 64))
        assert result.shape == (1, 8, 64, 64)
        torch.testing.assert_close(result, mask)

    def test_bare_spatial_mask(self):
        """[T, H, W] mask — should get batch dim added."""
        mask = torch.rand(8, 64, 64)
        result = run_resolve(mask, dims=(8, 64, 64))
        assert result.shape == (1, 8, 64, 64)
        # Only a batch dim is added; values must be untouched.
        torch.testing.assert_close(result, mask.unsqueeze(0))

    def test_resize_hw(self):
        """[B, T, H1, W1] mask with different H, W — should resize last 2 dims."""
        mask = torch.ones(1, 8, 32, 32)
        result = run_resolve(mask, dims=(8, 64, 64))
        assert result.shape == (1, 8, 64, 64)
        # Resizing a constant mask must keep it constant.
        torch.testing.assert_close(result, torch.ones(1, 8, 64, 64))

    def test_set_area_to_bounds_skipped_for_3d(self):
        """set_area_to_bounds should be skipped for 3D (no crash)."""
        mask = torch.zeros(1, 8, 64, 64)
        mask[0, :, 10:20, 10:30] = 1.0
        conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}]
        resolve_areas_and_cond_masks_multidim(conds, (8, 64, 64), "cpu")
        assert "area" not in conds[0]
        # Skipping the bounds step must still leave a correctly shaped mask.
        assert conds[0]["mask"].shape == (1, 8, 64, 64)
class TestNoMask:
    """Conditions that carry no mask must come out of the resolver untouched."""

    def test_no_mask_key(self):
        """A condition lacking a ``mask`` entry gets no mask added."""
        conditions = [{"model_conds": {}}]
        resolve_areas_and_cond_masks_multidim(conditions, (64, 64), "cpu")
        first = conditions[0]
        assert "mask" not in first

    def test_empty_conditions(self):
        """An empty conditions list is accepted without raising."""
        conditions = []
        resolve_areas_and_cond_masks_multidim(conditions, (64, 64), "cpu")
        assert not conditions
# ============================================================
# Area resolution (percentage-based)
# ============================================================
class TestAreaResolution:
    """Percentage-based areas resolve to absolute coordinates for any dims."""

    @staticmethod
    def _resolved_area(area, dims):
        """Run the resolver on a single area-only condition; return its area."""
        conditions = [{"area": area, "model_conds": {}}]
        resolve_areas_and_cond_masks_multidim(conditions, dims, "cpu")
        return conditions[0]["area"]

    def test_percentage_area_2d(self):
        """A 2D percentage area resolves to pixel coordinates."""
        resolved = self._resolved_area(
            ("percentage", 0.5, 0.5, 0.25, 0.25), (64, 64)
        )
        assert resolved == (32, 32, 16, 16)

    def test_percentage_area_1d(self):
        """A 1D percentage area resolves to frame coordinates."""
        resolved = self._resolved_area(("percentage", 0.5, 0.25), (100,))
        assert resolved == (50, 25)