Remove workaround for old pytorch. (#12480)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run

This commit is contained in:
comfyanonymous 2026-02-15 17:43:53 -08:00 committed by GitHub
parent c0370044cd
commit 88e6370527
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -102,19 +102,7 @@ class VideoConv3d(nn.Module):
return self.conv(x)
def interpolate_up(x, scale_factor):
try:
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
except: #operation not implemented for bf16
orig_shape = list(x.shape)
out_shape = orig_shape[:2]
for i in range(len(orig_shape) - 2):
out_shape.append(round(orig_shape[i + 2] * scale_factor[i]))
out = torch.empty(out_shape, dtype=x.dtype, layout=x.layout, device=x.device)
split = 8
l = out.shape[1] // split
for i in range(0, out.shape[1], l):
out[:,i:i+l] = torch.nn.functional.interpolate(x[:,i:i+l].to(torch.float32), scale_factor=scale_factor, mode="nearest").to(x.dtype)
return out
return torch.nn.functional.interpolate(x, scale_factor=scale_factor, mode="nearest")
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv, conv_op=ops.Conv2d, scale_factor=2.0):