mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 13:33:42 +08:00
Fix conditioning mask normalization for arbitrary spatial dimensions.
May also resolve #9784 — the mask normalization fixes a class of dimensionality mismatches that can cause the `y, x = torch.where(mask)` crash in `get_mask_aabb`, though the root cause in that report is unconfirmed. ## Summary `resolve_areas_and_cond_masks_multidim` assumes 2D spatial masks. This breaks for 1D audio models (StableAudio1, ACEAudio15) because upstream code (`ConditioningSetMask`, `set_mask_for_conditioning`) unconditionally unsqueezes masks with `ndim < 3`, corrupting valid `[B, L]` masks into `[1, B, L]` before they reach the sampler. This PR: - Normalizes masks to `[batch, *spatial_dims]` using `dims` as the source of truth - Adds a 1D resize path via `F.interpolate(mode='linear')` - Guards `set_area_to_bounds` with `len(dims) == 2` to prevent crashes on non-2D masks (the existing `get_mask_aabb` and `H, W, Y, X` unpacking are 2D-only) The root cause is the hardcoded `if len(mask.shape) < 3` in `nodes.py:242` and `hooks.py:725`. Fixing it there would require threading latent dimensionality into the conditioning nodes — a much larger change. Normalizing in `resolve_areas_and_cond_masks_multidim` where `dims` is already available is the minimal fix. Fully backwards compatible for existing 2D image and 3D video workflows. ## Test plan - [x] 26 unit tests covering 1D/2D/3D mask normalization, resize, and `set_area_to_bounds` guard (`tests-unit/comfy_test/samplers_test.py`) - [x] 2D image regression with hook masking: [lorahookmasking.json](https://github.com/Kosinkadink/ComfyUI/blob/workflows/lorahookmasking.json) - [x] 2D image with `set_area_to_bounds` ("mask bounds" mode) — no crash, correct area computation - [x] 1D audio with conditioning mask: [acestep-1.5-prompt-lora-blending.json](https://github.com/ryanontheinside/ComfyUI_RyanOnTheInside/blob/main/examples/ace1.5/acestep-1.5-prompt-lora-blending.json) (requires custom nodes that patch this function pending upstream)
This commit is contained in:
parent
e2c71ceb00
commit
abfea891ef
@ -543,15 +543,23 @@ def resolve_areas_and_cond_masks_multidim(conditions, dims, device):
|
||||
mask = c['mask']
|
||||
mask = mask.to(device=device)
|
||||
modified = c.copy()
|
||||
if len(mask.shape) == len(dims):
|
||||
# Normalize mask to [batch, *spatial_dims]
|
||||
target_ndim = len(dims) + 1
|
||||
while mask.ndim > target_ndim and mask.shape[0] == 1:
|
||||
mask = mask.squeeze(0)
|
||||
while mask.ndim < target_ndim:
|
||||
mask = mask.unsqueeze(0)
|
||||
if mask.shape[1:] != dims:
|
||||
if mask.ndim < 4:
|
||||
if len(dims) == 1:
|
||||
mask = torch.nn.functional.interpolate(
|
||||
mask.unsqueeze(1), size=dims[0],
|
||||
mode='linear', align_corners=False).squeeze(1)
|
||||
elif mask.ndim < 4:
|
||||
mask = comfy.utils.common_upscale(mask.unsqueeze(1), dims[-1], dims[-2], 'bilinear', 'none').squeeze(1)
|
||||
else:
|
||||
mask = comfy.utils.common_upscale(mask, dims[-1], dims[-2], 'bilinear', 'none')
|
||||
|
||||
if modified.get("set_area_to_bounds", False): #TODO: handle dim != 2
|
||||
if modified.get("set_area_to_bounds", False) and len(dims) == 2:
|
||||
bounds = torch.max(torch.abs(mask),dim=0).values.unsqueeze(0)
|
||||
boxes, is_empty = get_mask_aabb(bounds)
|
||||
if is_empty[0]:
|
||||
|
||||
230
tests-unit/comfy_test/samplers_test.py
Normal file
230
tests-unit/comfy_test/samplers_test.py
Normal file
@ -0,0 +1,230 @@
|
||||
"""
|
||||
Tests for resolve_areas_and_cond_masks_multidim mask normalization and resizing.
|
||||
|
||||
Covers 1D (audio), 2D (image), and 3D (video) spatial dims, including edge cases
|
||||
around batch size, spurious unsqueeze from upstream nodes, and size mismatches.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.samplers import resolve_areas_and_cond_masks_multidim
|
||||
|
||||
|
||||
def make_cond(mask):
|
||||
"""Create a minimal conditioning dict with a mask."""
|
||||
return {"mask": mask, "model_conds": {}}
|
||||
|
||||
|
||||
def run_resolve(mask, dims, device="cpu"):
|
||||
"""Run resolve on a single condition and return the resolved mask."""
|
||||
conds = [make_cond(mask)]
|
||||
resolve_areas_and_cond_masks_multidim(conds, dims, device)
|
||||
return conds[0]["mask"]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 1D spatial dims (audio models like AceStep v1.5)
|
||||
# dims = (length,), expected mask output: [batch, length]
|
||||
# ============================================================
|
||||
|
||||
class Test1DSpatial:
|
||||
"""Tests for 1D spatial models (e.g. audio with noise shape [B, C, L])."""
|
||||
|
||||
def test_correct_shape_same_length(self):
|
||||
"""[B, L] mask with matching length — should pass through unchanged."""
|
||||
mask = torch.ones(2, 100)
|
||||
result = run_resolve(mask, dims=(100,))
|
||||
assert result.shape == (2, 100)
|
||||
|
||||
def test_correct_shape_resize(self):
|
||||
"""[B, L] mask with different length — should resize via linear interp."""
|
||||
mask = torch.ones(1, 50)
|
||||
result = run_resolve(mask, dims=(100,))
|
||||
assert result.shape == (1, 100)
|
||||
|
||||
def test_bare_spatial_mask(self):
|
||||
"""[L] mask (no batch) — should get batch dim added."""
|
||||
mask = torch.ones(50)
|
||||
result = run_resolve(mask, dims=(100,))
|
||||
assert result.shape == (1, 100)
|
||||
|
||||
def test_spurious_unsqueeze_from_hooks(self):
|
||||
"""[1, B, L] mask (from set_mask_for_conditioning unsqueezing a [B, L] mask)
|
||||
— should squeeze back to [B, L]."""
|
||||
# Simulates: mask is [B, L], hooks.py does unsqueeze(0) -> [1, B, L]
|
||||
mask = torch.ones(1, 2, 100)
|
||||
result = run_resolve(mask, dims=(100,))
|
||||
assert result.shape == (2, 100)
|
||||
|
||||
def test_spurious_unsqueeze_batch1(self):
|
||||
"""[1, 1, L] mask (batch=1, hooks added extra dim) — should become [1, L]."""
|
||||
mask = torch.ones(1, 1, 50)
|
||||
result = run_resolve(mask, dims=(100,))
|
||||
assert result.shape == (1, 100)
|
||||
|
||||
def test_batch_gt1_same_length(self):
|
||||
"""[B, L] mask with batch=4 and matching length — no changes needed."""
|
||||
mask = torch.rand(4, 100)
|
||||
result = run_resolve(mask, dims=(100,))
|
||||
assert result.shape == (4, 100)
|
||||
torch.testing.assert_close(result, mask)
|
||||
|
||||
def test_batch_gt1_resize(self):
|
||||
"""[B, L] mask with batch=4 and different length — should resize each batch."""
|
||||
mask = torch.rand(4, 50)
|
||||
result = run_resolve(mask, dims=(100,))
|
||||
assert result.shape == (4, 100)
|
||||
|
||||
def test_values_preserved_no_resize(self):
|
||||
"""Mask values should be preserved when no resize is needed."""
|
||||
mask = torch.tensor([[0.0, 0.5, 1.0]])
|
||||
result = run_resolve(mask, dims=(3,))
|
||||
torch.testing.assert_close(result, mask)
|
||||
|
||||
def test_linear_interpolation_values(self):
|
||||
"""Check that linear interpolation produces sensible values."""
|
||||
mask = torch.tensor([[0.0, 1.0]]) # [1, 2]
|
||||
result = run_resolve(mask, dims=(5,))
|
||||
assert result.shape == (1, 5)
|
||||
# Should interpolate from 0 to 1
|
||||
assert result[0, 0].item() < result[0, -1].item()
|
||||
|
||||
def test_set_area_to_bounds_skipped_for_1d(self):
|
||||
"""set_area_to_bounds should be skipped for 1D (no crash)."""
|
||||
mask = torch.zeros(1, 100)
|
||||
mask[0, 10:50] = 1.0
|
||||
conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}]
|
||||
resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu")
|
||||
assert "area" not in conds[0]
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 2D spatial dims (image models) — regression tests
|
||||
# dims = (H, W), expected mask output: [batch, H, W]
|
||||
# ============================================================
|
||||
|
||||
class Test2DSpatial:
|
||||
"""Regression tests for standard 2D image models."""
|
||||
|
||||
def test_correct_shape_same_size(self):
|
||||
"""[B, H, W] mask matching dims — pass through."""
|
||||
mask = torch.ones(1, 64, 64)
|
||||
result = run_resolve(mask, dims=(64, 64))
|
||||
assert result.shape == (1, 64, 64)
|
||||
|
||||
def test_bare_spatial_mask(self):
|
||||
"""[H, W] mask — should get batch dim added."""
|
||||
mask = torch.ones(64, 64)
|
||||
result = run_resolve(mask, dims=(64, 64))
|
||||
assert result.shape == (1, 64, 64)
|
||||
|
||||
def test_resize_different_resolution(self):
|
||||
"""[B, H1, W1] mask with different size than dims — should bilinear resize."""
|
||||
mask = torch.ones(1, 32, 32)
|
||||
result = run_resolve(mask, dims=(64, 64))
|
||||
assert result.shape == (1, 64, 64)
|
||||
|
||||
def test_4d_mask(self):
|
||||
"""[B, C, H, W] mask (4D) — should resize via common_upscale 4D path."""
|
||||
mask = torch.ones(1, 1, 32, 32)
|
||||
result = run_resolve(mask, dims=(64, 64))
|
||||
assert result.shape == (1, 64, 64)
|
||||
|
||||
def test_batch_gt1(self):
|
||||
"""[B, H, W] mask with batch > 1."""
|
||||
mask = torch.rand(4, 64, 64)
|
||||
result = run_resolve(mask, dims=(64, 64))
|
||||
assert result.shape == (4, 64, 64)
|
||||
|
||||
def test_batch_gt1_resize(self):
|
||||
"""[B, H, W] mask with batch > 1 and different resolution."""
|
||||
mask = torch.rand(4, 32, 32)
|
||||
result = run_resolve(mask, dims=(64, 64))
|
||||
assert result.shape == (4, 64, 64)
|
||||
|
||||
def test_set_area_to_bounds(self):
|
||||
"""set_area_to_bounds should work for 2D masks."""
|
||||
mask = torch.zeros(1, 64, 64)
|
||||
mask[0, 10:20, 10:30] = 1.0
|
||||
conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}]
|
||||
resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu")
|
||||
assert "area" in conds[0]
|
||||
|
||||
def test_non_square_resize(self):
|
||||
"""[B, H1, W1] mask resized to non-square dims."""
|
||||
mask = torch.ones(1, 16, 32)
|
||||
result = run_resolve(mask, dims=(64, 128))
|
||||
assert result.shape == (1, 64, 128)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# 3D spatial dims (video models)
|
||||
# dims = (T, H, W), expected mask output: [batch, T, H, W]
|
||||
# ============================================================
|
||||
|
||||
class Test3DSpatial:
|
||||
"""Tests for 3D spatial models (e.g. video with noise shape [B, C, T, H, W])."""
|
||||
|
||||
def test_correct_shape_same_size(self):
|
||||
"""[B, T, H, W] mask matching dims — pass through."""
|
||||
mask = torch.ones(1, 8, 64, 64)
|
||||
result = run_resolve(mask, dims=(8, 64, 64))
|
||||
assert result.shape == (1, 8, 64, 64)
|
||||
|
||||
def test_bare_spatial_mask(self):
|
||||
"""[T, H, W] mask — should get batch dim added."""
|
||||
mask = torch.ones(8, 64, 64)
|
||||
result = run_resolve(mask, dims=(8, 64, 64))
|
||||
assert result.shape == (1, 8, 64, 64)
|
||||
|
||||
def test_resize_hw(self):
|
||||
"""[B, T, H1, W1] mask with different H, W — should resize last 2 dims."""
|
||||
mask = torch.ones(1, 8, 32, 32)
|
||||
result = run_resolve(mask, dims=(8, 64, 64))
|
||||
assert result.shape == (1, 8, 64, 64)
|
||||
|
||||
def test_set_area_to_bounds_skipped_for_3d(self):
|
||||
"""set_area_to_bounds should be skipped for 3D (no crash)."""
|
||||
mask = torch.zeros(1, 8, 64, 64)
|
||||
mask[0, :, 10:20, 10:30] = 1.0
|
||||
conds = [{"mask": mask, "model_conds": {}, "set_area_to_bounds": True}]
|
||||
resolve_areas_and_cond_masks_multidim(conds, (8, 64, 64), "cpu")
|
||||
assert "area" not in conds[0]
|
||||
|
||||
|
||||
class TestNoMask:
|
||||
"""Conditions without masks should pass through untouched."""
|
||||
|
||||
def test_no_mask_key(self):
|
||||
"""Condition with no mask key — untouched."""
|
||||
conds = [{"model_conds": {}}]
|
||||
resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu")
|
||||
assert "mask" not in conds[0]
|
||||
|
||||
def test_empty_conditions(self):
|
||||
"""Empty conditions list — no crash."""
|
||||
conds = []
|
||||
resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu")
|
||||
assert len(conds) == 0
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Area resolution (percentage-based)
|
||||
# ============================================================
|
||||
|
||||
class TestAreaResolution:
|
||||
"""Test that percentage-based area resolution works for different dims."""
|
||||
|
||||
def test_percentage_area_2d(self):
|
||||
"""Percentage area for 2D should resolve to pixel coords."""
|
||||
conds = [{"area": ("percentage", 0.5, 0.5, 0.25, 0.25), "model_conds": {}}]
|
||||
resolve_areas_and_cond_masks_multidim(conds, (64, 64), "cpu")
|
||||
area = conds[0]["area"]
|
||||
assert area == (32, 32, 16, 16)
|
||||
|
||||
def test_percentage_area_1d(self):
|
||||
"""Percentage area for 1D should resolve to frame coords."""
|
||||
conds = [{"area": ("percentage", 0.5, 0.25), "model_conds": {}}]
|
||||
resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu")
|
||||
area = conds[0]["area"]
|
||||
assert area == (50, 25)
|
||||
Loading…
Reference in New Issue
Block a user