mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-30 21:43:43 +08:00
Fix reshape_mask for 1D spatial dimensions.
reshape_mask sets scale_mode="linear" for dims==1 but is missing the input reshape to [N, C, W] that the 2D and 3D branches both perform. Add the missing reshape, matching the existing pattern.
This commit is contained in:
parent
abfea891ef
commit
9a5e6233f4
@ -1241,6 +1241,7 @@ def reshape_mask(input_mask, output_shape):
|
||||
dims = len(output_shape) - 2
|
||||
|
||||
if dims == 1:
|
||||
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-1]))
|
||||
scale_mode = "linear"
|
||||
|
||||
if dims == 2:
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
"""
|
||||
Tests for resolve_areas_and_cond_masks_multidim mask normalization and resizing.
|
||||
Tests for mask handling across arbitrary spatial dimensions.
|
||||
|
||||
Covers 1D (audio), 2D (image), and 3D (video) spatial dims, including edge cases
|
||||
around batch size, spurious unsqueeze from upstream nodes, and size mismatches.
|
||||
Covers resolve_areas_and_cond_masks_multidim (conditioning masks) and
|
||||
reshape_mask (denoise masks) for 1D (audio), 2D (image), and 3D (video)
|
||||
spatial dims, including edge cases around batch size, spurious unsqueeze
|
||||
from upstream nodes, and size mismatches.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from comfy.samplers import resolve_areas_and_cond_masks_multidim
|
||||
from comfy.utils import reshape_mask
|
||||
|
||||
|
||||
def make_cond(mask):
|
||||
@ -228,3 +231,106 @@ class TestAreaResolution:
|
||||
resolve_areas_and_cond_masks_multidim(conds, (100,), "cpu")
|
||||
area = conds[0]["area"]
|
||||
assert area == (50, 25)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# reshape_mask — mask reshaping for F.interpolate
|
||||
# ============================================================
|
||||
|
||||
class TestReshapeMask1D:
|
||||
"""Tests for reshape_mask with 1D output (e.g. audio with noise shape [B, C, L])."""
|
||||
|
||||
def test_4d_input_same_length(self):
|
||||
"""[1, 1, 1, L] input (typical from pipeline) — should reshape and expand channels."""
|
||||
mask = torch.ones(1, 1, 1, 100)
|
||||
result = reshape_mask(mask, torch.Size([1, 64, 100]))
|
||||
assert result.shape == (1, 64, 100)
|
||||
|
||||
def test_4d_input_resize(self):
|
||||
"""[1, 1, 1, L1] input resized to different length."""
|
||||
mask = torch.ones(1, 1, 1, 50)
|
||||
result = reshape_mask(mask, torch.Size([1, 64, 100]))
|
||||
assert result.shape == (1, 64, 100)
|
||||
|
||||
def test_3d_input(self):
|
||||
"""[1, 1, L] input — should work directly."""
|
||||
mask = torch.ones(1, 1, 100)
|
||||
result = reshape_mask(mask, torch.Size([1, 64, 100]))
|
||||
assert result.shape == (1, 64, 100)
|
||||
|
||||
def test_2d_input(self):
|
||||
"""[B, L] input — should reshape to [B, 1, L]."""
|
||||
mask = torch.ones(1, 50)
|
||||
result = reshape_mask(mask, torch.Size([1, 64, 100]))
|
||||
assert result.shape == (1, 64, 100)
|
||||
|
||||
def test_1d_input(self):
|
||||
"""[L] input — should reshape to [1, 1, L]."""
|
||||
mask = torch.ones(50)
|
||||
result = reshape_mask(mask, torch.Size([1, 64, 100]))
|
||||
assert result.shape == (1, 64, 100)
|
||||
|
||||
def test_channel_repeat(self):
|
||||
"""Mask with 1 channel should repeat to match output channels."""
|
||||
mask = torch.full((1, 1, 1, 100), 0.5)
|
||||
result = reshape_mask(mask, torch.Size([1, 32, 100]))
|
||||
assert result.shape == (1, 32, 100)
|
||||
torch.testing.assert_close(result, torch.full_like(result, 0.5))
|
||||
|
||||
def test_batch_repeat(self):
|
||||
"""Single-batch mask should repeat to match output batch size."""
|
||||
mask = torch.full((1, 1, 1, 100), 0.7)
|
||||
result = reshape_mask(mask, torch.Size([4, 64, 100]))
|
||||
assert result.shape == (4, 64, 100)
|
||||
|
||||
def test_values_preserved_no_resize(self):
|
||||
"""Values should be preserved when no resize is needed."""
|
||||
values = torch.tensor([[[0.0, 0.5, 1.0]]]) # [1, 1, 3]
|
||||
result = reshape_mask(values, torch.Size([1, 1, 3]))
|
||||
torch.testing.assert_close(result, values)
|
||||
|
||||
def test_interpolation_values(self):
|
||||
"""Linear interpolation should produce sensible intermediate values."""
|
||||
mask = torch.tensor([[[[0.0, 1.0]]]]) # [1, 1, 1, 2]
|
||||
result = reshape_mask(mask, torch.Size([1, 1, 4]))
|
||||
assert result.shape == (1, 1, 4)
|
||||
# Should interpolate from 0 to 1
|
||||
assert result[0, 0, 0].item() < result[0, 0, -1].item()
|
||||
|
||||
|
||||
class TestReshapeMask2D:
|
||||
"""Regression tests for reshape_mask with 2D output (image models)."""
|
||||
|
||||
def test_standard_resize(self):
|
||||
"""[1, 1, H, W] mask resized to different resolution."""
|
||||
mask = torch.ones(1, 1, 32, 32)
|
||||
result = reshape_mask(mask, torch.Size([1, 4, 64, 64]))
|
||||
assert result.shape == (1, 4, 64, 64)
|
||||
|
||||
def test_same_size(self):
|
||||
"""[1, 1, H, W] mask with matching size — no resize needed."""
|
||||
mask = torch.rand(1, 1, 64, 64)
|
||||
result = reshape_mask(mask, torch.Size([1, 4, 64, 64]))
|
||||
assert result.shape == (1, 4, 64, 64)
|
||||
|
||||
def test_3d_input(self):
|
||||
"""[B, H, W] input — should reshape to [B, 1, H, W]."""
|
||||
mask = torch.ones(1, 32, 32)
|
||||
result = reshape_mask(mask, torch.Size([1, 4, 64, 64]))
|
||||
assert result.shape == (1, 4, 64, 64)
|
||||
|
||||
|
||||
class TestReshapeMask3D:
|
||||
"""Regression tests for reshape_mask with 3D output (video models)."""
|
||||
|
||||
def test_standard_resize(self):
|
||||
"""[1, 1, T, H, W] mask resized to different resolution."""
|
||||
mask = torch.ones(1, 1, 8, 32, 32)
|
||||
result = reshape_mask(mask, torch.Size([1, 4, 8, 64, 64]))
|
||||
assert result.shape == (1, 4, 8, 64, 64)
|
||||
|
||||
def test_4d_input(self):
|
||||
"""[B, T, H, W] input — should reshape to [1, 1, T, H, W]."""
|
||||
mask = torch.ones(1, 8, 32, 32)
|
||||
result = reshape_mask(mask, torch.Size([1, 4, 8, 64, 64]))
|
||||
assert result.shape == (1, 4, 8, 64, 64)
|
||||
Loading…
Reference in New Issue
Block a user