comfy casting

This commit is contained in:
Yousef Rafat 2026-04-17 00:57:00 +02:00
parent 37f80907bc
commit 6a351d2f72

View File

@ -1,4 +1,5 @@
import torch
import comfy.ops
import numpy as np
import torch.nn as nn
from functools import partial
@ -422,7 +423,7 @@ class DeformableConv2d(nn.Module):
padding=self.padding,
bias=True, device=device, dtype=dtype)
self.regular_conv = torch.nn.Conv2d(in_channels,
self.regular_conv = operations.Conv2d(in_channels,
out_channels=out_channels,
kernel_size=kernel_size,
stride=stride,
@ -432,15 +433,18 @@ class DeformableConv2d(nn.Module):
def forward(self, x):
offset = self.offset_conv(x)
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
x = deform_conv2d(
input=x,
offset=offset,
weight=self.regular_conv.weight,
bias=self.regular_conv.bias,
weight=weight,
bias=None,
padding=self.padding,
mask=modulator,
stride=self.stride,
)
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
return x
class BasicDecBlk(nn.Module):