mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-02 05:12:36 +08:00
comfy casting
This commit is contained in:
parent
37f80907bc
commit
6a351d2f72
@ -1,4 +1,5 @@
|
|||||||
import torch
|
import torch
|
||||||
|
import comfy.ops
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from functools import partial
|
from functools import partial
|
||||||
@ -422,7 +423,7 @@ class DeformableConv2d(nn.Module):
|
|||||||
padding=self.padding,
|
padding=self.padding,
|
||||||
bias=True, device=device, dtype=dtype)
|
bias=True, device=device, dtype=dtype)
|
||||||
|
|
||||||
self.regular_conv = torch.nn.Conv2d(in_channels,
|
self.regular_conv = operations.Conv2d(in_channels,
|
||||||
out_channels=out_channels,
|
out_channels=out_channels,
|
||||||
kernel_size=kernel_size,
|
kernel_size=kernel_size,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
@ -432,15 +433,18 @@ class DeformableConv2d(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
offset = self.offset_conv(x)
|
offset = self.offset_conv(x)
|
||||||
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
modulator = 2. * torch.sigmoid(self.modulator_conv(x))
|
||||||
|
weight, bias, offload_info = comfy.ops.cast_bias_weight(self.regular_conv, x, offloadable=True)
|
||||||
|
|
||||||
x = deform_conv2d(
|
x = deform_conv2d(
|
||||||
input=x,
|
input=x,
|
||||||
offset=offset,
|
offset=offset,
|
||||||
weight=self.regular_conv.weight,
|
weight=weight,
|
||||||
bias=self.regular_conv.bias,
|
bias=None,
|
||||||
padding=self.padding,
|
padding=self.padding,
|
||||||
mask=modulator,
|
mask=modulator,
|
||||||
stride=self.stride,
|
stride=self.stride,
|
||||||
)
|
)
|
||||||
|
comfy.ops.uncast_bias_weight(self.regular_conv, weight, bias, offload_info)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
class BasicDecBlk(nn.Module):
|
class BasicDecBlk(nn.Module):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user