Mirror of https://github.com/comfyanonymous/ComfyUI.git
Merge branch 'comfyanonymous:master' into improve/maskeditor

commit e749c1f7ad
comfy/cldm/cldm.py

@@ -283,7 +283,7 @@ class ControlNet(nn.Module):
         return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
 
     def forward(self, x, hint, timesteps, context, y=None, **kwargs):
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
         emb = self.time_embed(t_emb)
 
         guided_hint = self.input_hint_block(hint, emb, context)
@@ -295,7 +295,7 @@ class ControlNet(nn.Module):
             assert y.shape[0] == x.shape[0]
             emb = emb + self.label_emb(y)
 
-        h = x.type(self.dtype)
+        h = x
         for module, zero_conv in zip(self.input_blocks, self.zero_convs):
            if guided_hint is not None:
                h = module(h, emb, context)
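The two edits above decouple activation precision from weight storage: the timestep embedding and the hidden state now follow the input tensor's dtype instead of self.dtype, which matters once model weights can be stored in fp8. A minimal standalone sketch of the idea (illustrative only, not ComfyUI code; fp8 dtypes assume PyTorch >= 2.1):

import torch

# Weights stored in a compact dtype; activations follow the input.
weight = torch.randn(8, 8).to(torch.float8_e4m3fn)   # storage dtype
x = torch.randn(4, 8, dtype=torch.float16)           # compute dtype comes from x

# fp8 tensors can't be used in matmul directly; cast up to the input dtype at use time.
y = x @ weight.to(x.dtype).t()
print(y.dtype)  # torch.float16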
comfy/cli_args.py

@@ -55,7 +55,10 @@ fp_group = parser.add_mutually_exclusive_group()
 fp_group.add_argument("--force-fp32", action="store_true", help="Force fp32 (If this makes your GPU work better please report it).")
 fp_group.add_argument("--force-fp16", action="store_true", help="Force fp16.")
 
-parser.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group = parser.add_mutually_exclusive_group()
+fpunet_group.add_argument("--bf16-unet", action="store_true", help="Run the UNET in bf16. This should only be used for testing stuff.")
+fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true", help="Store unet weights in fp8_e4m3fn.")
+fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true", help="Store unet weights in fp8_e5m2.")
 
 fpvae_group = parser.add_mutually_exclusive_group()
 fpvae_group.add_argument("--fp16-vae", action="store_true", help="Run the VAE in fp16, might cause black images.")
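Because the three UNET precision flags now live in one mutually exclusive group, argparse rejects contradictory combinations at parse time instead of silently preferring one. A self-contained sketch of the pattern:

import argparse

parser = argparse.ArgumentParser()
fpunet_group = parser.add_mutually_exclusive_group()
fpunet_group.add_argument("--bf16-unet", action="store_true")
fpunet_group.add_argument("--fp8_e4m3fn-unet", action="store_true")
fpunet_group.add_argument("--fp8_e5m2-unet", action="store_true")

args = parser.parse_args(["--fp8_e4m3fn-unet"])
print(args.fp8_e4m3fn_unet)  # True

# parser.parse_args(["--bf16-unet", "--fp8_e5m2-unet"]) would exit with:
#   error: argument --fp8_e5m2-unet: not allowed with argument --bf16-unet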
comfy/controlnet.py

@@ -1,10 +1,12 @@
 import torch
 import math
 import os
+import contextlib
 import comfy.utils
 import comfy.model_management
 import comfy.model_detection
 import comfy.model_patcher
+import comfy.ops
 
 import comfy.cldm.cldm
 import comfy.t2i_adapter.adapter
@@ -146,24 +148,31 @@ class ControlNet(ControlBase):
             else:
                 return None
 
+        dtype = self.control_model.dtype
+        if comfy.model_management.supports_dtype(self.device, dtype):
+            precision_scope = lambda a: contextlib.nullcontext(a)
+        else:
+            precision_scope = torch.autocast
+            dtype = torch.float32
+
         output_dtype = x_noisy.dtype
         if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]:
             if self.cond_hint is not None:
                 del self.cond_hint
             self.cond_hint = None
-            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(self.control_model.dtype).to(self.device)
+            self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").to(dtype).to(self.device)
         if x_noisy.shape[0] != self.cond_hint.shape[0]:
             self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
 
 
         context = cond['c_crossattn']
         y = cond.get('y', None)
         if y is not None:
-            y = y.to(self.control_model.dtype)
+            y = y.to(dtype)
         timestep = self.model_sampling_current.timestep(t)
         x_noisy = self.model_sampling_current.calculate_input(t, x_noisy)
 
-        control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(self.control_model.dtype), y=y)
+        with precision_scope(comfy.model_management.get_autocast_device(self.device)):
+            control = self.control_model(x=x_noisy.to(dtype), hint=self.cond_hint, timesteps=timestep.float(), context=context.to(dtype), y=y)
         return self.control_merge(None, control, control_prev, output_dtype)
 
     def copy(self):
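get_control now decides once per call whether to run the ControlNet natively in its weight dtype (a no-op context) or to fall back to torch.autocast with fp32 inputs when the device can't handle that dtype. A standalone sketch of that dispatch; supports_dtype here is a simplified stand-in for the new comfy.model_management.supports_dtype:

import contextlib
import torch

def supports_dtype(device, dtype):  # simplified stand-in
    if dtype == torch.float32:
        return True
    return device.type != "cpu"

device = torch.device("cpu")
dtype = torch.float16

if supports_dtype(device, dtype):
    precision_scope = lambda a: contextlib.nullcontext(a)  # run natively
else:
    precision_scope = torch.autocast  # let autocast pick per-op precision
    dtype = torch.float32             # and feed it fp32 inputs

with precision_scope(device.type):
    x = torch.randn(2, 4, dtype=dtype)
    print((x @ x.t()).dtype)  # bfloat16 here: CPU autocast downcasts matmul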
@@ -199,7 +208,7 @@ class ControlLoraOps:
 
         def forward(self, input):
             if self.up is not None:
-                return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
+                return torch.nn.functional.linear(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
             else:
                 return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
 
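The fix above casts the stored base weight to the input dtype before adding the low-rank delta, so a low-precision base weight and the fp32 up/down product end up in the same dtype before the addition. The merge-on-the-fly pattern in isolation (names and shapes are illustrative):

import torch

x = torch.randn(2, 16, dtype=torch.float16)
weight = torch.randn(8, 16, dtype=torch.float16)  # frozen base weight
up = torch.randn(8, 4, dtype=torch.float32)       # low-rank factors
down = torch.randn(4, 16, dtype=torch.float32)

# Effective weight: W + up @ down, assembled in the input's dtype.
delta = torch.mm(up.flatten(start_dim=1), down.flatten(start_dim=1)).reshape(weight.shape)
w_eff = weight.to(x.dtype).to(x.device) + delta.type(x.dtype)
print(torch.nn.functional.linear(x, w_eff).shape)  # torch.Size([2, 8])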
@@ -238,7 +247,7 @@ class ControlLoraOps:
 
         def forward(self, input):
             if self.up is not None:
-                return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
+                return torch.nn.functional.conv2d(input, self.weight.to(input.dtype).to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
             else:
                 return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
 
@@ -248,6 +257,15 @@ class ControlLoraOps:
         else:
             raise ValueError(f"unsupported dimensions: {dims}")
 
+    class Conv3d(comfy.ops.Conv3d):
+        pass
+
+    class GroupNorm(comfy.ops.GroupNorm):
+        pass
+
+    class LayerNorm(comfy.ops.LayerNorm):
+        pass
+
 
 class ControlLora(ControlNet):
     def __init__(self, control_weights, global_average_pooling=False, device=None):
comfy/ldm/modules/attention.py

@@ -83,16 +83,6 @@ class FeedForward(nn.Module):
     def forward(self, x):
         return self.net(x)
 
 
-def zero_module(module):
-    """
-    Zero out the parameters of a module and return it.
-    """
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
 def Normalize(in_channels, dtype=None, device=None):
     return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
 
@@ -414,10 +404,10 @@ class BasicTransformerBlock(nn.Module):
 
             self.attn2 = CrossAttention(query_dim=inner_dim, context_dim=context_dim_attn2,
                                 heads=n_heads, dim_head=d_head, dropout=dropout, dtype=dtype, device=device, operations=operations)  # is self-attn if context is none
-            self.norm2 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+            self.norm2 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
 
-        self.norm1 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
-        self.norm3 = nn.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm1 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
+        self.norm3 = operations.LayerNorm(inner_dim, dtype=dtype, device=device)
         self.checkpoint = checkpoint
         self.n_heads = n_heads
         self.d_head = d_head
@@ -559,7 +549,7 @@ class SpatialTransformer(nn.Module):
             context_dim = [context_dim] * depth
         self.in_channels = in_channels
         inner_dim = n_heads * d_head
-        self.norm = Normalize(in_channels, dtype=dtype, device=device)
+        self.norm = operations.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device)
         if not use_linear:
             self.proj_in = operations.Conv2d(in_channels,
                                   inner_dim,
comfy/ldm/modules/diffusionmodules/openaimodel.py

@@ -177,7 +177,7 @@ class ResBlock(TimestepBlock):
             padding = kernel_size // 2
 
         self.in_layers = nn.Sequential(
-            nn.GroupNorm(32, channels, dtype=dtype, device=device),
+            operations.GroupNorm(32, channels, dtype=dtype, device=device),
             nn.SiLU(),
             operations.conv_nd(dims, channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device),
         )
@@ -206,12 +206,11 @@ class ResBlock(TimestepBlock):
             ),
         )
         self.out_layers = nn.Sequential(
-            nn.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
+            operations.GroupNorm(32, self.out_channels, dtype=dtype, device=device),
             nn.SiLU(),
             nn.Dropout(p=dropout),
-            zero_module(
-                operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
-            ),
+            operations.conv_nd(dims, self.out_channels, self.out_channels, kernel_size, padding=padding, dtype=dtype, device=device)
+            ,
         )
 
         if self.out_channels == channels:
@@ -810,13 +809,13 @@ class UNetModel(nn.Module):
         self._feature_size += ch
 
         self.out = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+            operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
             nn.SiLU(),
             zero_module(operations.conv_nd(dims, model_channels, out_channels, 3, padding=1, dtype=self.dtype, device=device)),
         )
         if self.predict_codebook_ids:
             self.id_predictor = nn.Sequential(
-            nn.GroupNorm(32, ch, dtype=self.dtype, device=device),
+            operations.GroupNorm(32, ch, dtype=self.dtype, device=device),
             operations.conv_nd(dims, model_channels, n_embed, 1, dtype=self.dtype, device=device),
             #nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
             )
@@ -842,14 +841,14 @@ class UNetModel(nn.Module):
             self.num_classes is not None
         ), "must specify y if and only if the model is class-conditional"
         hs = []
-        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
+        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
         emb = self.time_embed(t_emb)
 
         if self.num_classes is not None:
             assert y.shape[0] == x.shape[0]
             emb = emb + self.label_emb(y)
 
-        h = x.type(self.dtype)
+        h = x
         for id, module in enumerate(self.input_blocks):
             transformer_options["block"] = ("input", id)
             h = forward_timestep_embed(module, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
comfy/model_base.py

@@ -5,6 +5,7 @@ from comfy.ldm.modules.diffusionmodules.openaimodel import Timestep
 import comfy.model_management
 import comfy.conds
 from enum import Enum
+import contextlib
 from . import utils
 
 class ModelType(Enum):
@@ -61,6 +62,13 @@ class BaseModel(torch.nn.Module):
 
         context = c_crossattn
         dtype = self.get_dtype()
+
+        if comfy.model_management.supports_dtype(xc.device, dtype):
+            precision_scope = lambda a: contextlib.nullcontext(a)
+        else:
+            precision_scope = torch.autocast
+            dtype = torch.float32
+
         xc = xc.to(dtype)
         t = self.model_sampling.timestep(t).float()
         context = context.to(dtype)
@@ -70,7 +78,10 @@ class BaseModel(torch.nn.Module):
             if hasattr(extra, "to"):
                 extra = extra.to(dtype)
             extra_conds[o] = extra
-        model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+
+        with precision_scope(comfy.model_management.get_autocast_device(xc.device)):
+            model_output = self.diffusion_model(xc, t, context=context, control=control, transformer_options=transformer_options, **extra_conds).float()
+
         return self.model_sampling.calculate_denoised(sigma, model_output, x)
 
     def get_dtype(self):
comfy/model_management.py

@@ -430,6 +430,13 @@ def dtype_size(dtype):
     dtype_size = 4
     if dtype == torch.float16 or dtype == torch.bfloat16:
         dtype_size = 2
+    elif dtype == torch.float32:
+        dtype_size = 4
+    else:
+        try:
+            dtype_size = dtype.itemsize
+        except: #Old pytorch doesn't have .itemsize
+            pass
     return dtype_size
 
 def unet_offload_device():
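The dtype_size extension exists because fp8 weights are one byte per element, which the old two-branch logic could not report. Newer PyTorch exposes element size directly as dtype.itemsize, and the try/except keeps older versions on the 4-byte default. A quick check of the values (itemsize and fp8 dtypes assume PyTorch >= 2.1):

import torch

for dtype in (torch.float32, torch.float16, torch.bfloat16):
    print(dtype, dtype.itemsize)  # 4, 2, 2

if hasattr(torch, "float8_e4m3fn"):
    print(torch.float8_e4m3fn.itemsize, torch.float8_e5m2.itemsize)  # 1 1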
@@ -459,6 +466,10 @@ def unet_inital_load_device(parameters, dtype):
 def unet_dtype(device=None, model_params=0):
     if args.bf16_unet:
         return torch.bfloat16
+    if args.fp8_e4m3fn_unet:
+        return torch.float8_e4m3fn
+    if args.fp8_e5m2_unet:
+        return torch.float8_e5m2
     if should_use_fp16(device=device, model_params=model_params):
         return torch.float16
     return torch.float32
@@ -515,6 +526,17 @@ def get_autocast_device(dev):
         return dev.type
     return "cuda"
 
+def supports_dtype(device, dtype): #TODO
+    if dtype == torch.float32:
+        return True
+    if torch.device("cpu") == device:
+        return False
+    if dtype == torch.float16:
+        return True
+    if dtype == torch.bfloat16:
+        return True
+    return False
+
 def cast_to_device(tensor, device, dtype, copy=False):
     device_supports_cast = False
     if tensor.dtype == torch.float32 or tensor.dtype == torch.float16:
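supports_dtype is deliberately coarse (the #TODO is in the source): fp32 always passes, CPU rejects everything else, and any non-CPU device is assumed to handle fp16/bf16; fp8 falls through to False, which is what routes fp8-stored models into the autocast branch above. Expected behaviour, assuming comfy is importable:

import torch
import comfy.model_management as mm

print(mm.supports_dtype(torch.device("cpu"), torch.float32))         # True
print(mm.supports_dtype(torch.device("cpu"), torch.float16))         # False
print(mm.supports_dtype(torch.device("cuda"), torch.bfloat16))       # True
print(mm.supports_dtype(torch.device("cuda"), torch.float8_e4m3fn))  # False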
comfy/ops.py

@@ -13,6 +13,14 @@ class Conv3d(torch.nn.Conv3d):
     def reset_parameters(self):
         return None
 
+class GroupNorm(torch.nn.GroupNorm):
+    def reset_parameters(self):
+        return None
+
+class LayerNorm(torch.nn.LayerNorm):
+    def reset_parameters(self):
+        return None
+
 def conv_nd(dims, *args, **kwargs):
     if dims == 2:
         return Conv2d(*args, **kwargs)
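The new comfy.ops.GroupNorm and comfy.ops.LayerNorm mirror the existing Conv/Linear wrappers: identical to the torch modules except that reset_parameters is a no-op, so construction skips the default weight init (checkpoint weights overwrite it anyway) and ControlLoraOps can subclass them. The effect in isolation:

import torch

class LayerNorm(torch.nn.LayerNorm):
    def reset_parameters(self):
        return None  # leave the torch.empty() storage as-is; loaded weights land here later

ln = LayerNorm(4)
print(ln.weight)  # uninitialized values, not the usual ones-init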
comfy/sample.py

@@ -101,7 +101,7 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
     samples = samples.cpu()
 
     cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
+    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples
 
 def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
@@ -113,6 +113,6 @@ def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent
     samples = comfy.samplers.sample(real_model, noise, positive_copy, negative_copy, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
     samples = samples.cpu()
     cleanup_additional_models(models)
-    cleanup_additional_models(set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")))
+    cleanup_additional_models(set(get_models_from_cond(positive_copy, "control") + get_models_from_cond(negative_copy, "control")))
     return samples
comfy/sd1_clip.py

@@ -84,12 +84,16 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
 
         self.inner_name = inner_name
         if dtype is not None:
-            self.transformer.to(dtype)
             inner_model = getattr(self.transformer, self.inner_name)
             if hasattr(inner_model, "embeddings"):
-                inner_model.embeddings.to(torch.float32)
+                embeddings_bak = inner_model.embeddings.to(torch.float32)
+                inner_model.embeddings = None
+                self.transformer.to(dtype)
+                inner_model.embeddings = embeddings_bak
             else:
-                self.transformer.set_input_embeddings(self.transformer.get_input_embeddings().to(torch.float32))
+                previous_inputs = self.transformer.get_input_embeddings().to(torch.float32, copy=True)
+                self.transformer.to(dtype)
+                self.transformer.set_input_embeddings(previous_inputs)
 
         self.max_length = max_length
         if freeze:
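The reworked cast keeps the token embeddings in fp32 end to end: they are detached before the transformer is cast (possibly to fp16 or fp8) and reattached afterwards, instead of being cast down and then back up, which would quantize them. The same pattern on a toy module (standalone sketch, not the ComfyUI class):

import torch

class TinyTextModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = torch.nn.Embedding(10, 4)
        self.proj = torch.nn.Linear(4, 4)

model = TinyTextModel()

embeddings_bak = model.embeddings.to(torch.float32)  # Module.to is in-place
model.embeddings = None                              # detach so the cast below skips it
model.to(torch.float16)
model.embeddings = embeddings_bak                    # reattach the untouched fp32 table

print(model.embeddings.weight.dtype, model.proj.weight.dtype)
# torch.float32 torch.float16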