From 6a351d2f723e3a2cc6e10b5872b8d0a81790b8e4 Mon Sep 17 00:00:00 2001 From: Yousef Rafat <81116377+yousef-rafat@users.noreply.github.com> Date: Fri, 17 Apr 2026 00:57:00 +0200 Subject: [PATCH] comfy casting --- comfy/image_encoders/birefnet.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/comfy/image_encoders/birefnet.py b/comfy/image_encoders/birefnet.py index 167c8c067..25ca5b57e 100644 --- a/comfy/image_encoders/birefnet.py +++ b/comfy/image_encoders/birefnet.py @@ -1,4 +1,5 @@ import torch +import comfy.ops import numpy as np import torch.nn as nn from functools import partial @@ -422,7 +423,7 @@ class DeformableConv2d(nn.Module): padding=self.padding, bias=True, device=device, dtype=dtype) - self.regular_conv = torch.nn.Conv2d(in_channels, + self.regular_conv = operations.Conv2d(in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, @@ -432,15 +433,18 @@ class DeformableConv2d(nn.Module): def forward(self, x): offset = self.offset_conv(x) modulator = 2. * torch.sigmoid(self.modulator_conv(x)) + weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True) + x = deform_conv2d( input=x, offset=offset, - weight=self.regular_conv.weight, - bias=self.regular_conv.bias, + weight=weight, + bias=None, padding=self.padding, mask=modulator, stride=self.stride, ) + comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info) return x class BasicDecBlk(nn.Module):