Fix dtype issue with hidream o1 (#13849)

This commit is contained in:
comfyanonymous 2026-05-11 20:53:13 -07:00 committed by GitHub
parent 8e53f001a4
commit 0155ddcbe3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -451,9 +451,8 @@ class Qwen35VisionPatchEmbed(nn.Module):
self.proj = ops.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True, device=device, dtype=dtype)
def forward(self, x):
target_dtype = self.proj.weight.dtype
x = x.view(-1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size)
return self.proj(x.to(target_dtype)).view(-1, self.embed_dim)
return self.proj(x).view(-1, self.embed_dim)
class Qwen35VisionMLP(nn.Module):