Avoid ROCm Conv3d crash in Qwen35 vision patch embedding by using equivalent linear projection

This commit is contained in:
Peter Willemsen 2026-06-01 19:46:26 +02:00
parent 0b610bd63a
commit 39c12cf789
No known key found for this signature in database
GPG Key ID: 80C2AB4961493721

View File

@ -452,6 +452,14 @@ class Qwen35VisionPatchEmbed(nn.Module):
def forward(self, x):
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
if (
comfy.model_management.is_amd()
and x.is_cuda
and x.dtype in (torch.float16, torch.bfloat16)
):
# This Conv3d is a full-patch projection, equivalent to Linear.
# Avoid the ROCm/MIOpen reduced-precision Conv3d kernel that can segfault.
return F.linear(x.flatten(1), self.proj.weight.flatten(1), self.proj.bias)
return self.proj(x).view(-1, self.embed_dim)