comfy casting

This commit is contained in:
Yousef Rafat 2026-04-17 00:57:00 +02:00
parent 37f80907bc
commit 6a351d2f72

View File

@ -1,4 +1,5 @@
import torch import torch
import comfy.ops
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
from functools import partial from functools import partial
@ -422,7 +423,7 @@ class DeformableConv2d(nn.Module):
padding=self.padding, padding=self.padding,
bias=True, device=device, dtype=dtype) bias=True, device=device, dtype=dtype)
self.regular_conv = torch.nn.Conv2d(in_channels, self.regular_conv = operations.Conv2d(in_channels,
out_channels=out_channels, out_channels=out_channels,
kernel_size=kernel_size, kernel_size=kernel_size,
stride=stride, stride=stride,
@ -432,15 +433,18 @@ class DeformableConv2d(nn.Module):
def forward(self, x): def forward(self, x):
offset = self.offset_conv(x) offset = self.offset_conv(x)
modulator = 2. * torch.sigmoid(self.modulator_conv(x)) modulator = 2. * torch.sigmoid(self.modulator_conv(x))
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
x = deform_conv2d( x = deform_conv2d(
input=x, input=x,
offset=offset, offset=offset,
weight=self.regular_conv.weight, weight=weight,
bias=self.regular_conv.bias, bias=None,
padding=self.padding, padding=self.padding,
mask=modulator, mask=modulator,
stride=self.stride, stride=self.stride,
) )
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
return x return x
class BasicDecBlk(nn.Module): class BasicDecBlk(nn.Module):