Fix reshape_mask for 1D spatial dimensions.

reshape_mask sets scale_mode="linear" for dims==1 but is missing the input reshape to [N, C, W] that the 2D and 3D branches both perform. Add the missing reshape, matching the existing pattern.
This commit is contained in:
RyanOnTheInside 2026-02-15 10:53:49 -05:00
parent abfea891ef
commit 9a5e6233f4
2 changed files with 110 additions and 3 deletions

View File

@ -1241,6 +1241,7 @@ def reshape_mask(input_mask, output_shape):
dims = len(output_shape) - 2
if dims == 1:
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-1]))
scale_mode = "linear"
if dims == 2:

View File

@ -1,13 +1,16 @@
"""
Tests for resolve_areas_and_cond_masks_multidim mask normalization and resizing.
Tests for mask handling across arbitrary spatial dimensions.
Covers 1D (audio), 2D (image), and 3D (video) spatial dims, including edge cases
around batch size, spurious unsqueeze from upstream nodes, and size mismatches.
Covers resolve_areas_and_cond_masks_multidim (conditioning masks) and
reshape_mask (denoise masks) for 1D (audio), 2D (image), and 3D (video)
spatial dims, including edge cases around batch size, spurious unsqueeze
from upstream nodes, and size mismatches.
"""
import torch
from comfy.samplers import resolve_areas_and_cond_masks_multidim
from comfy.utils import reshape_mask
def make_cond(mask):
@ -228,3 +231,106 @@ class TestAreaResolution:
resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu")
area = conds[0]["area"]
assert area == (50, 25)
# ============================================================
# reshape_mask — mask reshaping for F.interpolate
# ============================================================
class TestReshapeMask1D:
"""Tests for reshape_mask with 1D output (e.g. audio with noise shape [B, C, L])."""
def test_4d_input_same_length(self):
"""[1, 1, 1, L] input (typical from pipeline) — should reshape and expand channels."""
mask = torch.ones(1, 1, 1, 100)
result = reshape_mask(mask, torch.Size([1, 64, 100]))
assert result.shape == (1, 64, 100)
def test_4d_input_resize(self):
"""[1, 1, 1, L1] input resized to different length."""
mask = torch.ones(1, 1, 1, 50)
result = reshape_mask(mask, torch.Size([1, 64, 100]))
assert result.shape == (1, 64, 100)
def test_3d_input(self):
"""[1, 1, L] input — should work directly."""
mask = torch.ones(1, 1, 100)
result = reshape_mask(mask, torch.Size([1, 64, 100]))
assert result.shape == (1, 64, 100)
def test_2d_input(self):
"""[B, L] input — should reshape to [B, 1, L]."""
mask = torch.ones(1, 50)
result = reshape_mask(mask, torch.Size([1, 64, 100]))
assert result.shape == (1, 64, 100)
def test_1d_input(self):
"""[L] input — should reshape to [1, 1, L]."""
mask = torch.ones(50)
result = reshape_mask(mask, torch.Size([1, 64, 100]))
assert result.shape == (1, 64, 100)
def test_channel_repeat(self):
"""Mask with 1 channel should repeat to match output channels."""
mask = torch.full((1, 1, 1, 100), 0.5)
result = reshape_mask(mask, torch.Size([1, 32, 100]))
assert result.shape == (1, 32, 100)
torch.testing.assert_close(result, torch.full_like(result, 0.5))
def test_batch_repeat(self):
"""Single-batch mask should repeat to match output batch size."""
mask = torch.full((1, 1, 1, 100), 0.7)
result = reshape_mask(mask, torch.Size([4, 64, 100]))
assert result.shape == (4, 64, 100)
def test_values_preserved_no_resize(self):
"""Values should be preserved when no resize is needed."""
values = torch.tensor([[[0.0, 0.5, 1.0]]]) # [1, 1, 3]
result = reshape_mask(values, torch.Size([1, 1, 3]))
torch.testing.assert_close(result, values)
def test_interpolation_values(self):
"""Linear interpolation should produce sensible intermediate values."""
mask = torch.tensor([[[[0.0, 1.0]]]]) # [1, 1, 1, 2]
result = reshape_mask(mask, torch.Size([1, 1, 4]))
assert result.shape == (1, 1, 4)
# Should interpolate from 0 to 1
assert result[0, 0, 0].item() < result[0, 0, -1].item()
class TestReshapeMask2D:
"""Regression tests for reshape_mask with 2D output (image models)."""
def test_standard_resize(self):
"""[1, 1, H, W] mask resized to different resolution."""
mask = torch.ones(1, 1, 32, 32)
result = reshape_mask(mask, torch.Size([1, 4, 64, 64]))
assert result.shape == (1, 4, 64, 64)
def test_same_size(self):
"""[1, 1, H, W] mask with matching size — no resize needed."""
mask = torch.rand(1, 1, 64, 64)
result = reshape_mask(mask, torch.Size([1, 4, 64, 64]))
assert result.shape == (1, 4, 64, 64)
def test_3d_input(self):
"""[B, H, W] input — should reshape to [B, 1, H, W]."""
mask = torch.ones(1, 32, 32)
result = reshape_mask(mask, torch.Size([1, 4, 64, 64]))
assert result.shape == (1, 4, 64, 64)
class TestReshapeMask3D:
"""Regression tests for reshape_mask with 3D output (video models)."""
def test_standard_resize(self):
"""[1, 1, T, H, W] mask resized to different resolution."""
mask = torch.ones(1, 1, 8, 32, 32)
result = reshape_mask(mask, torch.Size([1, 4, 8, 64, 64]))
assert result.shape == (1, 4, 8, 64, 64)
def test_4d_input(self):
"""[B, T, H, W] input — should reshape to [1, 1, T, H, W]."""
mask = torch.ones(1, 8, 32, 32)
result = reshape_mask(mask, torch.Size([1, 4, 8, 64, 64]))
assert result.shape == (1, 4, 8, 64, 64)