mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 21:43:43 +08:00
May also resolve #9784 — the mask normalization fixes a class of dimensionality mismatches that can cause the `y, x = torch.where(mask)` crash in `get_mask_aabb`, though the root cause in that report is unconfirmed.

## Summary

`resolve_areas_and_cond_masks_multidim` assumes 2D spatial masks. This breaks for 1D audio models (StableAudio1, ACEAudio15) because upstream code (`ConditioningSetMask`, `set_mask_for_conditioning`) unconditionally unsqueezes masks with `ndim < 3`, corrupting valid `[B, L]` masks into `[1, B, L]` before they reach the sampler.

This PR:

- Normalizes masks to `[batch, *spatial_dims]`, using `dims` as the source of truth
- Adds a 1D resize path via `F.interpolate(mode='linear')`
- Guards `set_area_to_bounds` with `len(dims) == 2` to prevent crashes on non-2D masks (the existing `get_mask_aabb` and the `H, W, Y, X` unpacking are 2D-only)

The root cause is the hardcoded `if len(mask.shape) < 3` in `nodes.py:242` and `hooks.py:725`. Fixing it there would require threading latent dimensionality into the conditioning nodes — a much larger change. Normalizing in `resolve_areas_and_cond_masks_multidim`, where `dims` is already available, is the minimal fix.

Fully backwards compatible with existing 2D image and 3D video workflows.

## Test plan

- [x] 26 unit tests covering 1D/2D/3D mask normalization, resize, and the `set_area_to_bounds` guard (`tests-unit/comfy_test/samplers_test.py`)
- [x] 2D image regression with hook masking: [lorahookmasking.json](https://github.com/Kosinkadink/ComfyUI/blob/workflows/lorahookmasking.json)
- [x] 2D image with `set_area_to_bounds` ("mask bounds" mode) — no crash, correct area computation
- [x] 1D audio with conditioning mask: [acestep-1.5-prompt-lora-blending.json](https://github.com/ryanontheinside/ComfyUI_RyanOnTheInside/blob/main/examples/ace1.5/acestep-1.5-prompt-lora-blending.json) (requires custom nodes that patch this function until the change lands upstream)
231 lines
9.0 KiB
Python
231 lines
9.0 KiB
Python
"""
|
|
Tests for resolve_areas_and_cond_masks_multidim mask normalization and resizing.
|
|
|
|
Covers 1D (audio), 2D (image), and 3D (video) spatial dims, including edge cases
|
|
around batch size, spurious unsqueeze from upstream nodes, and size mismatches.
|
|
"""
|
|
|
|
import torch
|
|
|
|
from comfy.samplers import resolve_areas_and_cond_masks_multidim
|
|
|
|
|
|
def make_cond(mask):
    """Build the smallest conditioning dict the resolver accepts, carrying *mask*.

    A fresh, empty ``model_conds`` dict is created on every call so tests never
    share mutable state.
    """
    condition = {"mask": mask}
    condition["model_conds"] = {}
    return condition
|
|
|
|
|
|
def run_resolve(mask, dims, device="cpu"):
    """Resolve one mask-only condition and return the mask it ends up with.

    The resolver reports its output by updating the entries of the list it is
    given, so the result is read back out of that list after the call.
    """
    conditions = [make_cond(mask)]
    resolve_areas_and_cond_masks_multidim(conditions, dims, device)
    first = conditions[0]
    return first["mask"]
|
|
|
|
|
|
# ============================================================
|
|
# 1D spatial dims (audio models like AceStep v1.5)
|
|
# dims = (length,), expected mask output: [batch, length]
|
|
# ============================================================
|
|
|
|
class Test1DSpatial:
    """Tests for 1D spatial models (e.g. audio with noise shape [B, C, L]).

    Tests check mask values, not just shapes, so a normalization that yields
    the right shape from the wrong data (mixed-up batch rows, a dropped or
    duplicated axis) still fails.
    """

    def test_correct_shape_same_length(self):
        """[B, L] mask with matching length — should pass through unchanged."""
        mask = torch.rand(2, 100)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (2, 100)
        torch.testing.assert_close(result, mask)

    def test_correct_shape_resize(self):
        """[B, L] mask with different length — should resize via linear interp."""
        mask = torch.ones(1, 50)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (1, 100)
        # Linear interpolation of a constant signal must stay constant.
        torch.testing.assert_close(result, torch.ones(1, 100))

    def test_bare_spatial_mask(self):
        """[L] mask (no batch) — should get batch dim added."""
        mask = torch.ones(50)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (1, 100)
        torch.testing.assert_close(result, torch.ones(1, 100))

    def test_spurious_unsqueeze_from_hooks(self):
        """[1, B, L] mask (from set_mask_for_conditioning unsqueezing a [B, L] mask)
        — should squeeze back to [B, L]."""
        # Simulates: mask is [B, L], hooks.py does unsqueeze(0) -> [1, B, L]
        original = torch.rand(2, 100)
        mask = original.unsqueeze(0)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (2, 100)
        # Each batch row must come back untouched, not merged or reordered.
        torch.testing.assert_close(result, original)

    def test_spurious_unsqueeze_batch1(self):
        """[1, 1, L] mask (batch=1, hooks added extra dim) — should become [1, L]."""
        mask = torch.ones(1, 1, 50)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (1, 100)
        torch.testing.assert_close(result, torch.ones(1, 100))

    def test_batch_gt1_same_length(self):
        """[B, L] mask with batch=4 and matching length — no changes needed."""
        mask = torch.rand(4, 100)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (4, 100)
        torch.testing.assert_close(result, mask)

    def test_batch_gt1_resize(self):
        """[B, L] mask with batch=4 and different length — should resize each batch."""
        # Row k holds a single constant level, so any mixing between batch rows
        # during the resize shows up as a value mismatch.
        levels = torch.tensor([[0.0], [0.25], [0.5], [1.0]])
        mask = levels.repeat(1, 50)
        result = run_resolve(mask, dims=(100,))
        assert result.shape == (4, 100)
        torch.testing.assert_close(result, levels.repeat(1, 100))

    def test_values_preserved_no_resize(self):
        """Mask values should be preserved when no resize is needed."""
        mask = torch.tensor([[0.0, 0.5, 1.0]])
        result = run_resolve(mask, dims=(3,))
        torch.testing.assert_close(result, mask)

    def test_linear_interpolation_values(self):
        """Check that linear interpolation produces sensible values."""
        mask = torch.tensor([[0.0, 1.0]])  # [1, 2]
        result = run_resolve(mask, dims=(5,))
        assert result.shape == (1, 5)
        values = result[0]
        # Should interpolate from 0 to 1. The endpoints reproduce the input
        # samples, the ramp never decreases, and the midpoint lands halfway.
        # These properties hold for linear interpolation with either
        # align_corners setting.
        assert torch.all(values[1:] >= values[:-1])
        torch.testing.assert_close(values[0], torch.tensor(0.0))
        torch.testing.assert_close(values[2], torch.tensor(0.5))
        torch.testing.assert_close(values[-1], torch.tensor(1.0))

    def test_set_area_to_bounds_skipped_for_1d(self):
        """set_area_to_bounds should be skipped for 1D (no crash)."""
        mask = torch.zeros(1, 100)
        mask[0, 10:50] = 1.0
        conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}]
        resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu")
        assert "area" not in conds[0]
        # The mask itself is still normalized even though no area is derived.
        assert conds[0]["mask"].shape == (1, 100)
|
|
|
|
|
|
# ============================================================
|
|
# 2D spatial dims (image models) — regression tests
|
|
# dims = (H, W), expected mask output: [batch, H, W]
|
|
# ============================================================
|
|
|
|
class Test2DSpatial:
    """Regression tests for standard 2D image models.

    Pass-through cases use random masks and are compared value-for-value.
    Resize cases use constant (per-batch) masks, which bilinear resizing must
    keep constant.
    """

    def test_correct_shape_same_size(self):
        """[B, H, W] mask matching dims — pass through."""
        mask = torch.rand(1, 64, 64)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (1, 64, 64)
        torch.testing.assert_close(result, mask)

    def test_bare_spatial_mask(self):
        """[H, W] mask — should get batch dim added."""
        mask = torch.rand(64, 64)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (1, 64, 64)
        torch.testing.assert_close(result, mask.unsqueeze(0))

    def test_resize_different_resolution(self):
        """[B, H1, W1] mask with different size than dims — should bilinear resize."""
        mask = torch.ones(1, 32, 32)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (1, 64, 64)
        torch.testing.assert_close(result, torch.ones(1, 64, 64))

    def test_4d_mask(self):
        """[B, C, H, W] mask (4D) — should resize via common_upscale 4D path."""
        mask = torch.ones(1, 1, 32, 32)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (1, 64, 64)
        torch.testing.assert_close(result, torch.ones(1, 64, 64))

    def test_batch_gt1(self):
        """[B, H, W] mask with batch > 1."""
        mask = torch.rand(4, 64, 64)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (4, 64, 64)
        torch.testing.assert_close(result, mask)

    def test_batch_gt1_resize(self):
        """[B, H, W] mask with batch > 1 and different resolution."""
        # Each batch item is a different constant level, so the resized items
        # must stay constant and keep their own level.
        levels = torch.tensor([0.0, 0.25, 0.5, 1.0]).view(4, 1, 1)
        mask = levels.repeat(1, 32, 32)
        result = run_resolve(mask, dims=(64, 64))
        assert result.shape == (4, 64, 64)
        torch.testing.assert_close(result, levels.repeat(1, 64, 64))

    def test_set_area_to_bounds(self):
        """set_area_to_bounds should work for 2D masks."""
        mask = torch.zeros(1, 64, 64)
        mask[0, 10:20, 10:30] = 1.0  # rows 10..19, columns 10..29
        conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}]
        resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu")
        assert "area" in conds[0]
        # (H, W, Y, X): a 10-row by 20-column box with its top-left at (10, 10).
        assert conds[0]["area"] == (10, 20, 10, 10)

    def test_non_square_resize(self):
        """[B, H1, W1] mask resized to non-square dims."""
        mask = torch.ones(1, 16, 32)
        result = run_resolve(mask, dims=(64, 128))
        assert result.shape == (1, 64, 128)
        torch.testing.assert_close(result, torch.ones(1, 64, 128))
|
|
|
|
|
|
# ============================================================
|
|
# 3D spatial dims (video models)
|
|
# dims = (T, H, W), expected mask output: [batch, T, H, W]
|
|
# ============================================================
|
|
|
|
class Test3DSpatial:
    """Tests for 3D spatial models (e.g. video with noise shape [B, C, T, H, W]).

    Besides shapes, these tests check that no-resize cases preserve values and
    that H/W resizing keeps every frame separate and in order.
    """

    def test_correct_shape_same_size(self):
        """[B, T, H, W] mask matching dims — pass through."""
        mask = torch.rand(1, 8, 64, 64)
        result = run_resolve(mask, dims=(8, 64, 64))
        assert result.shape == (1, 8, 64, 64)
        torch.testing.assert_close(result, mask)

    def test_bare_spatial_mask(self):
        """[T, H, W] mask — should get batch dim added."""
        mask = torch.rand(8, 64, 64)
        result = run_resolve(mask, dims=(8, 64, 64))
        assert result.shape == (1, 8, 64, 64)
        torch.testing.assert_close(result, mask.unsqueeze(0))

    def test_resize_hw(self):
        """[B, T, H1, W1] mask with different H, W — should resize last 2 dims."""
        # Frame t holds the constant level t / 7. Resizing only H and W must
        # leave each frame constant and keep the frames in order.
        levels = torch.linspace(0.0, 1.0, 8).view(1, 8, 1, 1)
        mask = levels.repeat(1, 1, 32, 32)
        result = run_resolve(mask, dims=(8, 64, 64))
        assert result.shape == (1, 8, 64, 64)
        torch.testing.assert_close(result, levels.repeat(1, 1, 64, 64))

    def test_set_area_to_bounds_skipped_for_3d(self):
        """set_area_to_bounds should be skipped for 3D (no crash)."""
        mask = torch.zeros(1, 8, 64, 64)
        mask[0, :, 10:20, 10:30] = 1.0
        conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}]
        resolve_areas_and_cond_masks_multidim(conds, (8, 64, 64), "cpu")
        assert "area" not in conds[0]
|
|
|
|
|
|
class TestNoMask:
    """The resolver must leave mask-free conditioning alone."""

    def test_no_mask_key(self):
        """A condition that never had a mask must not gain one."""
        condition = {"model_conds": {}}
        conditions = [condition]
        resolve_areas_and_cond_masks_multidim(conditions, (64, 64), "cpu")
        resolved = conditions[0]
        assert "mask" not in resolved

    def test_empty_conditions(self):
        """Resolving an empty list is a harmless no-op."""
        no_conditions = []
        resolve_areas_and_cond_masks_multidim(no_conditions, (64, 64), "cpu")
        assert not no_conditions
|
|
|
|
|
|
# ============================================================
|
|
# Area resolution (percentage-based)
|
|
# ============================================================
|
|
|
|
class TestAreaResolution:
    """Percentage areas must be turned into absolute coordinates for each dimensionality."""

    def test_percentage_area_2d(self):
        """2D percentages are scaled against dims of (64, 64)."""
        spec = ("percentage", 0.5, 0.5, 0.25, 0.25)
        conditions = [{"area": spec, "model_conds": {}}]
        resolve_areas_and_cond_masks_multidim(conditions, (64, 64), "cpu")
        resolved_area = conditions[0]["area"]
        assert resolved_area == (32, 32, 16, 16)

    def test_percentage_area_1d(self):
        """1D percentages are scaled against a single length of 100."""
        spec = ("percentage", 0.5, 0.25)
        conditions = [{"area": spec, "model_conds": {}}]
        resolve_areas_and_cond_masks_multidim(conditions, (100,), "cpu")
        resolved_area = conditions[0]["area"]
        assert resolved_area == (50, 25)
|