mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-30 12:22:37 +08:00
comfy casting
This commit is contained in:
parent
37f80907bc
commit
6a351d2f72
@ -1,4 +1,5 @@
|
||||
import torch
|
||||
import comfy.ops
|
||||
import numpy as np
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
@ -422,7 +423,7 @@ class DeformableConv2d(nn.Module):
|
||||
padding=self.padding,
|
||||
bias=True, device=device, dtype=dtype)
|
||||
|
||||
self.regular_conv = torch.nn.Conv2d(in_channels,
|
||||
self.regular_conv = operations.Conv2d(in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
@ -432,15 +433,18 @@ class DeformableConv2d(nn.Module):
|
||||
def forward(self, x):
|
||||
offset = self.offset_conv(x)
|
||||
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
||||
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
|
||||
|
||||
x = deform_conv2d(
|
||||
input=x,
|
||||
offset=offset,
|
||||
weight=self.regular_conv.weight,
|
||||
bias=self.regular_conv.bias,
|
||||
weight=weight,
|
||||
bias=None,
|
||||
padding=self.padding,
|
||||
mask=modulator,
|
||||
stride=self.stride,
|
||||
)
|
||||
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
|
||||
return x
|
||||
|
||||
class BasicDecBlk(nn.Module):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user