mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-10 13:32:36 +08:00
Merge branch 'comfyanonymous:master' into multiple_workflows
This commit is contained in:
commit
f9c999a537
@ -279,7 +279,7 @@ class ControlNet(nn.Module):
|
|||||||
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
|
return TimestepEmbedSequential(zero_module(operations.conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||||
|
|
||||||
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
def forward(self, x, hint, timesteps, context, y=None, **kwargs):
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
|
||||||
emb = self.time_embed(t_emb)
|
emb = self.time_embed(t_emb)
|
||||||
|
|
||||||
guided_hint = self.input_hint_block(hint, emb, context)
|
guided_hint = self.input_hint_block(hint, emb, context)
|
||||||
@ -287,9 +287,6 @@ class ControlNet(nn.Module):
|
|||||||
outs = []
|
outs = []
|
||||||
|
|
||||||
hs = []
|
hs = []
|
||||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
|
||||||
emb = self.time_embed(t_emb)
|
|
||||||
|
|
||||||
if self.num_classes is not None:
|
if self.num_classes is not None:
|
||||||
assert y.shape[0] == x.shape[0]
|
assert y.shape[0] == x.shape[0]
|
||||||
emb = emb + self.label_emb(y)
|
emb = emb + self.label_emb(y)
|
||||||
|
|||||||
@ -58,6 +58,8 @@ fpvae_group.add_argument("--bf16-vae", action="store_true", help="Run the VAE in
|
|||||||
|
|
||||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||||
|
|
||||||
|
parser.add_argument("--disable-ipex-optimize", action="store_true", help="Disables ipex.optimize when loading models with Intel GPUs.")
|
||||||
|
|
||||||
class LatentPreviewMethod(enum.Enum):
|
class LatentPreviewMethod(enum.Enum):
|
||||||
NoPreviews = "none"
|
NoPreviews = "none"
|
||||||
Auto = "auto"
|
Auto = "auto"
|
||||||
|
|||||||
@ -632,7 +632,9 @@ class UNetModel(nn.Module):
|
|||||||
transformer_options["block"] = ("middle", 0)
|
transformer_options["block"] = ("middle", 0)
|
||||||
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
|
||||||
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
if control is not None and 'middle' in control and len(control['middle']) > 0:
|
||||||
h += control['middle'].pop()
|
ctrl = control['middle'].pop()
|
||||||
|
if ctrl is not None:
|
||||||
|
h += ctrl
|
||||||
|
|
||||||
for id, module in enumerate(self.output_blocks):
|
for id, module in enumerate(self.output_blocks):
|
||||||
transformer_options["block"] = ("output", id)
|
transformer_options["block"] = ("output", id)
|
||||||
|
|||||||
@ -88,8 +88,10 @@ def get_total_memory(dev=None, torch_total_too=False):
|
|||||||
mem_total = 1024 * 1024 * 1024 #TODO
|
mem_total = 1024 * 1024 * 1024 #TODO
|
||||||
mem_total_torch = mem_total
|
mem_total_torch = mem_total
|
||||||
elif xpu_available:
|
elif xpu_available:
|
||||||
|
stats = torch.xpu.memory_stats(dev)
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
mem_total = torch.xpu.get_device_properties(dev).total_memory
|
||||||
mem_total_torch = mem_total
|
mem_total_torch = mem_reserved
|
||||||
else:
|
else:
|
||||||
stats = torch.cuda.memory_stats(dev)
|
stats = torch.cuda.memory_stats(dev)
|
||||||
mem_reserved = stats['reserved_bytes.all.current']
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
@ -208,6 +210,7 @@ if DISABLE_SMART_MEMORY:
|
|||||||
print("Disabling smart memory management")
|
print("Disabling smart memory management")
|
||||||
|
|
||||||
def get_torch_device_name(device):
|
def get_torch_device_name(device):
|
||||||
|
global xpu_available
|
||||||
if hasattr(device, 'type'):
|
if hasattr(device, 'type'):
|
||||||
if device.type == "cuda":
|
if device.type == "cuda":
|
||||||
try:
|
try:
|
||||||
@ -217,6 +220,8 @@ def get_torch_device_name(device):
|
|||||||
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
|
return "{} {} : {}".format(device, torch.cuda.get_device_name(device), allocator_backend)
|
||||||
else:
|
else:
|
||||||
return "{}".format(device.type)
|
return "{}".format(device.type)
|
||||||
|
elif xpu_available:
|
||||||
|
return "{} {}".format(device, torch.xpu.get_device_name(device))
|
||||||
else:
|
else:
|
||||||
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
return "CUDA {}: {}".format(device, torch.cuda.get_device_name(device))
|
||||||
|
|
||||||
@ -244,6 +249,7 @@ class LoadedModel:
|
|||||||
return self.model_memory()
|
return self.model_memory()
|
||||||
|
|
||||||
def model_load(self, lowvram_model_memory=0):
|
def model_load(self, lowvram_model_memory=0):
|
||||||
|
global xpu_available
|
||||||
patch_model_to = None
|
patch_model_to = None
|
||||||
if lowvram_model_memory == 0:
|
if lowvram_model_memory == 0:
|
||||||
patch_model_to = self.device
|
patch_model_to = self.device
|
||||||
@ -264,6 +270,9 @@ class LoadedModel:
|
|||||||
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
|
accelerate.dispatch_model(self.real_model, device_map=device_map, main_device=self.device)
|
||||||
self.model_accelerated = True
|
self.model_accelerated = True
|
||||||
|
|
||||||
|
if xpu_available and not args.disable_ipex_optimize:
|
||||||
|
self.real_model = torch.xpu.optimize(self.real_model.eval(), inplace=True, auto_kernel_selection=True, graph_mode=True)
|
||||||
|
|
||||||
return self.real_model
|
return self.real_model
|
||||||
|
|
||||||
def model_unload(self):
|
def model_unload(self):
|
||||||
@ -397,6 +406,9 @@ def unet_inital_load_device(parameters, dtype):
|
|||||||
return torch_dev
|
return torch_dev
|
||||||
|
|
||||||
cpu_dev = torch.device("cpu")
|
cpu_dev = torch.device("cpu")
|
||||||
|
if DISABLE_SMART_MEMORY:
|
||||||
|
return cpu_dev
|
||||||
|
|
||||||
dtype_size = 4
|
dtype_size = 4
|
||||||
if dtype == torch.float16 or dtype == torch.bfloat16:
|
if dtype == torch.float16 or dtype == torch.bfloat16:
|
||||||
dtype_size = 2
|
dtype_size = 2
|
||||||
@ -497,8 +509,12 @@ def get_free_memory(dev=None, torch_free_too=False):
|
|||||||
mem_free_total = 1024 * 1024 * 1024 #TODO
|
mem_free_total = 1024 * 1024 * 1024 #TODO
|
||||||
mem_free_torch = mem_free_total
|
mem_free_torch = mem_free_total
|
||||||
elif xpu_available:
|
elif xpu_available:
|
||||||
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - torch.xpu.memory_allocated(dev)
|
stats = torch.xpu.memory_stats(dev)
|
||||||
mem_free_torch = mem_free_total
|
mem_active = stats['active_bytes.all.current']
|
||||||
|
mem_allocated = stats['allocated_bytes.all.current']
|
||||||
|
mem_reserved = stats['reserved_bytes.all.current']
|
||||||
|
mem_free_torch = mem_reserved - mem_active
|
||||||
|
mem_free_total = torch.xpu.get_device_properties(dev).total_memory - mem_allocated
|
||||||
else:
|
else:
|
||||||
stats = torch.cuda.memory_stats(dev)
|
stats = torch.cuda.memory_stats(dev)
|
||||||
mem_active = stats['active_bytes.all.current']
|
mem_active = stats['active_bytes.all.current']
|
||||||
@ -570,9 +586,12 @@ def should_use_fp16(device=None, model_params=0):
|
|||||||
if directml_enabled:
|
if directml_enabled:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if cpu_mode() or mps_mode() or xpu_available:
|
if cpu_mode() or mps_mode():
|
||||||
return False #TODO ?
|
return False #TODO ?
|
||||||
|
|
||||||
|
if xpu_available:
|
||||||
|
return True
|
||||||
|
|
||||||
if torch.cuda.is_bf16_supported():
|
if torch.cuda.is_bf16_supported():
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
181
comfy/sd.py
181
comfy/sd.py
@ -2,6 +2,7 @@ import torch
|
|||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
|
import math
|
||||||
|
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
from .ldm.util import instantiate_from_config
|
from .ldm.util import instantiate_from_config
|
||||||
@ -243,6 +244,13 @@ def set_attr(obj, attr, value):
|
|||||||
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
setattr(obj, attrs[-1], torch.nn.Parameter(value))
|
||||||
del prev
|
del prev
|
||||||
|
|
||||||
|
def get_attr(obj, attr):
|
||||||
|
attrs = attr.split(".")
|
||||||
|
for name in attrs:
|
||||||
|
obj = getattr(obj, name)
|
||||||
|
return obj
|
||||||
|
|
||||||
|
|
||||||
class ModelPatcher:
|
class ModelPatcher:
|
||||||
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
def __init__(self, model, load_device, offload_device, size=0, current_device=None):
|
||||||
self.size = size
|
self.size = size
|
||||||
@ -649,7 +657,7 @@ class VAE:
|
|||||||
def decode(self, samples_in):
|
def decode(self, samples_in):
|
||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
try:
|
try:
|
||||||
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.4
|
memory_used = (2562 * samples_in.shape[2] * samples_in.shape[3] * 64) * 1.7
|
||||||
model_management.free_memory(memory_used, self.device)
|
model_management.free_memory(memory_used, self.device)
|
||||||
free_memory = model_management.get_free_memory(self.device)
|
free_memory = model_management.get_free_memory(self.device)
|
||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / memory_used)
|
||||||
@ -677,7 +685,7 @@ class VAE:
|
|||||||
self.first_stage_model = self.first_stage_model.to(self.device)
|
self.first_stage_model = self.first_stage_model.to(self.device)
|
||||||
pixel_samples = pixel_samples.movedim(-1,1)
|
pixel_samples = pixel_samples.movedim(-1,1)
|
||||||
try:
|
try:
|
||||||
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.4 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
|
memory_used = (2078 * pixel_samples.shape[2] * pixel_samples.shape[3]) * 1.7 #NOTE: this constant along with the one in the decode above are estimated from the mem usage for the VAE and could change.
|
||||||
model_management.free_memory(memory_used, self.device)
|
model_management.free_memory(memory_used, self.device)
|
||||||
free_memory = model_management.get_free_memory(self.device)
|
free_memory = model_management.get_free_memory(self.device)
|
||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / memory_used)
|
||||||
@ -735,6 +743,7 @@ class ControlBase:
|
|||||||
device = model_management.get_torch_device()
|
device = model_management.get_torch_device()
|
||||||
self.device = device
|
self.device = device
|
||||||
self.previous_controlnet = None
|
self.previous_controlnet = None
|
||||||
|
self.global_average_pooling = False
|
||||||
|
|
||||||
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(1.0, 0.0)):
|
||||||
self.cond_hint_original = cond_hint
|
self.cond_hint_original = cond_hint
|
||||||
@ -770,6 +779,51 @@ class ControlBase:
|
|||||||
c.strength = self.strength
|
c.strength = self.strength
|
||||||
c.timestep_percent_range = self.timestep_percent_range
|
c.timestep_percent_range = self.timestep_percent_range
|
||||||
|
|
||||||
|
def control_merge(self, control_input, control_output, control_prev, output_dtype):
|
||||||
|
out = {'input':[], 'middle':[], 'output': []}
|
||||||
|
|
||||||
|
if control_input is not None:
|
||||||
|
for i in range(len(control_input)):
|
||||||
|
key = 'input'
|
||||||
|
x = control_input[i]
|
||||||
|
if x is not None:
|
||||||
|
x *= self.strength
|
||||||
|
if x.dtype != output_dtype:
|
||||||
|
x = x.to(output_dtype)
|
||||||
|
out[key].insert(0, x)
|
||||||
|
|
||||||
|
if control_output is not None:
|
||||||
|
for i in range(len(control_output)):
|
||||||
|
if i == (len(control_output) - 1):
|
||||||
|
key = 'middle'
|
||||||
|
index = 0
|
||||||
|
else:
|
||||||
|
key = 'output'
|
||||||
|
index = i
|
||||||
|
x = control_output[i]
|
||||||
|
if x is not None:
|
||||||
|
if self.global_average_pooling:
|
||||||
|
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
||||||
|
|
||||||
|
x *= self.strength
|
||||||
|
if x.dtype != output_dtype:
|
||||||
|
x = x.to(output_dtype)
|
||||||
|
|
||||||
|
out[key].append(x)
|
||||||
|
if control_prev is not None:
|
||||||
|
for x in ['input', 'middle', 'output']:
|
||||||
|
o = out[x]
|
||||||
|
for i in range(len(control_prev[x])):
|
||||||
|
prev_val = control_prev[x][i]
|
||||||
|
if i >= len(o):
|
||||||
|
o.append(prev_val)
|
||||||
|
elif prev_val is not None:
|
||||||
|
if o[i] is None:
|
||||||
|
o[i] = prev_val
|
||||||
|
else:
|
||||||
|
o[i] += prev_val
|
||||||
|
return out
|
||||||
|
|
||||||
class ControlNet(ControlBase):
|
class ControlNet(ControlBase):
|
||||||
def __init__(self, control_model, global_average_pooling=False, device=None):
|
def __init__(self, control_model, global_average_pooling=False, device=None):
|
||||||
super().__init__(device)
|
super().__init__(device)
|
||||||
@ -798,41 +852,13 @@ class ControlNet(ControlBase):
|
|||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
|
|
||||||
if self.control_model.dtype == torch.float16:
|
|
||||||
precision_scope = torch.autocast
|
|
||||||
else:
|
|
||||||
precision_scope = contextlib.nullcontext
|
|
||||||
|
|
||||||
with precision_scope(model_management.get_autocast_device(self.device)):
|
context = torch.cat(cond['c_crossattn'], 1)
|
||||||
context = torch.cat(cond['c_crossattn'], 1)
|
y = cond.get('c_adm', None)
|
||||||
y = cond.get('c_adm', None)
|
if y is not None:
|
||||||
control = self.control_model(x=x_noisy, hint=self.cond_hint, timesteps=t, context=context, y=y)
|
y = y.to(self.control_model.dtype)
|
||||||
out = {'middle':[], 'output': []}
|
control = self.control_model(x=x_noisy.to(self.control_model.dtype), hint=self.cond_hint, timesteps=t, context=context.to(self.control_model.dtype), y=y)
|
||||||
autocast_enabled = torch.is_autocast_enabled()
|
return self.control_merge(None, control, control_prev, output_dtype)
|
||||||
|
|
||||||
for i in range(len(control)):
|
|
||||||
if i == (len(control) - 1):
|
|
||||||
key = 'middle'
|
|
||||||
index = 0
|
|
||||||
else:
|
|
||||||
key = 'output'
|
|
||||||
index = i
|
|
||||||
x = control[i]
|
|
||||||
if self.global_average_pooling:
|
|
||||||
x = torch.mean(x, dim=(2, 3), keepdim=True).repeat(1, 1, x.shape[2], x.shape[3])
|
|
||||||
|
|
||||||
x *= self.strength
|
|
||||||
if x.dtype != output_dtype and not autocast_enabled:
|
|
||||||
x = x.to(output_dtype)
|
|
||||||
|
|
||||||
if control_prev is not None and key in control_prev:
|
|
||||||
prev = control_prev[key][index]
|
|
||||||
if prev is not None:
|
|
||||||
x += prev
|
|
||||||
out[key].append(x)
|
|
||||||
if control_prev is not None and 'input' in control_prev:
|
|
||||||
out['input'] = control_prev['input']
|
|
||||||
return out
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
c = ControlNet(self.control_model, global_average_pooling=self.global_average_pooling)
|
||||||
@ -859,9 +885,9 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.linear(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias)
|
return torch.nn.functional.linear(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.linear(input, self.weight, self.bias)
|
return torch.nn.functional.linear(input, self.weight.to(input.device), self.bias)
|
||||||
|
|
||||||
class Conv2d(torch.nn.Module):
|
class Conv2d(torch.nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -898,9 +924,9 @@ class ControlLoraOps:
|
|||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
if self.up is not None:
|
if self.up is not None:
|
||||||
return torch.nn.functional.conv2d(input, self.weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(self.weight.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
return torch.nn.functional.conv2d(input, self.weight.to(input.device) + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
else:
|
else:
|
||||||
return torch.nn.functional.conv2d(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
|
return torch.nn.functional.conv2d(input, self.weight.to(input.device), self.bias, self.stride, self.padding, self.dilation, self.groups)
|
||||||
|
|
||||||
def conv_nd(self, dims, *args, **kwargs):
|
def conv_nd(self, dims, *args, **kwargs):
|
||||||
if dims == 2:
|
if dims == 2:
|
||||||
@ -922,22 +948,28 @@ class ControlLora(ControlNet):
|
|||||||
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
controlnet_config["hint_channels"] = self.control_weights["input_hint_block.0.weight"].shape[1]
|
||||||
controlnet_config["operations"] = ControlLoraOps()
|
controlnet_config["operations"] = ControlLoraOps()
|
||||||
self.control_model = cldm.ControlNet(**controlnet_config)
|
self.control_model = cldm.ControlNet(**controlnet_config)
|
||||||
if model_management.should_use_fp16():
|
dtype = model.get_dtype()
|
||||||
self.control_model.half()
|
self.control_model.to(dtype)
|
||||||
self.control_model.to(model_management.get_torch_device())
|
self.control_model.to(model_management.get_torch_device())
|
||||||
diffusion_model = model.diffusion_model
|
diffusion_model = model.diffusion_model
|
||||||
sd = diffusion_model.state_dict()
|
sd = diffusion_model.state_dict()
|
||||||
cm = self.control_model.state_dict()
|
cm = self.control_model.state_dict()
|
||||||
|
|
||||||
for k in sd:
|
for k in sd:
|
||||||
|
weight = sd[k]
|
||||||
|
if weight.device == torch.device("meta"): #lowvram NOTE: this depends on the inner working of the accelerate library so it might break.
|
||||||
|
key_split = k.split('.') # I have no idea why they don't just leave the weight there instead of using the meta device.
|
||||||
|
op = get_attr(diffusion_model, '.'.join(key_split[:-1]))
|
||||||
|
weight = op._hf_hook.weights_map[key_split[-1]]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
set_attr(self.control_model, k, sd[k])
|
set_attr(self.control_model, k, weight)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
for k in self.control_weights:
|
for k in self.control_weights:
|
||||||
if k not in {"lora_controlnet"}:
|
if k not in {"lora_controlnet"}:
|
||||||
set_attr(self.control_model, k, self.control_weights[k].to(model_management.get_torch_device()))
|
set_attr(self.control_model, k, self.control_weights[k].to(dtype).to(model_management.get_torch_device()))
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
c = ControlLora(self.control_weights, global_average_pooling=self.global_average_pooling)
|
||||||
@ -1068,6 +1100,12 @@ class T2IAdapter(ControlBase):
|
|||||||
self.channels_in = channels_in
|
self.channels_in = channels_in
|
||||||
self.control_input = None
|
self.control_input = None
|
||||||
|
|
||||||
|
def scale_image_to(self, width, height):
|
||||||
|
unshuffle_amount = self.t2i_model.unshuffle_amount
|
||||||
|
width = math.ceil(width / unshuffle_amount) * unshuffle_amount
|
||||||
|
height = math.ceil(height / unshuffle_amount) * unshuffle_amount
|
||||||
|
return width, height
|
||||||
|
|
||||||
def get_control(self, x_noisy, t, cond, batched_number):
|
def get_control(self, x_noisy, t, cond, batched_number):
|
||||||
control_prev = None
|
control_prev = None
|
||||||
if self.previous_controlnet is not None:
|
if self.previous_controlnet is not None:
|
||||||
@ -1085,44 +1123,24 @@ class T2IAdapter(ControlBase):
|
|||||||
del self.cond_hint
|
del self.cond_hint
|
||||||
self.control_input = None
|
self.control_input = None
|
||||||
self.cond_hint = None
|
self.cond_hint = None
|
||||||
self.cond_hint = utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * 8, x_noisy.shape[2] * 8, 'nearest-exact', "center").float().to(self.device)
|
width, height = self.scale_image_to(x_noisy.shape[3] * 8, x_noisy.shape[2] * 8)
|
||||||
|
self.cond_hint = utils.common_upscale(self.cond_hint_original, width, height, 'nearest-exact', "center").float().to(self.device)
|
||||||
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
if self.channels_in == 1 and self.cond_hint.shape[1] > 1:
|
||||||
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
self.cond_hint = torch.mean(self.cond_hint, 1, keepdim=True)
|
||||||
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
if x_noisy.shape[0] != self.cond_hint.shape[0]:
|
||||||
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
self.cond_hint = broadcast_image_to(self.cond_hint, x_noisy.shape[0], batched_number)
|
||||||
if self.control_input is None:
|
if self.control_input is None:
|
||||||
|
self.t2i_model.to(x_noisy.dtype)
|
||||||
self.t2i_model.to(self.device)
|
self.t2i_model.to(self.device)
|
||||||
self.control_input = self.t2i_model(self.cond_hint)
|
self.control_input = self.t2i_model(self.cond_hint.to(x_noisy.dtype))
|
||||||
self.t2i_model.cpu()
|
self.t2i_model.cpu()
|
||||||
|
|
||||||
output_dtype = x_noisy.dtype
|
control_input = list(map(lambda a: None if a is None else a.clone(), self.control_input))
|
||||||
out = {'input':[]}
|
mid = None
|
||||||
|
if self.t2i_model.xl == True:
|
||||||
autocast_enabled = torch.is_autocast_enabled()
|
mid = control_input[-1:]
|
||||||
for i in range(len(self.control_input)):
|
control_input = control_input[:-1]
|
||||||
key = 'input'
|
return self.control_merge(control_input, mid, control_prev, x_noisy.dtype)
|
||||||
x = self.control_input[i] * self.strength
|
|
||||||
if x.dtype != output_dtype and not autocast_enabled:
|
|
||||||
x = x.to(output_dtype)
|
|
||||||
|
|
||||||
if control_prev is not None and key in control_prev:
|
|
||||||
index = len(control_prev[key]) - i * 3 - 3
|
|
||||||
prev = control_prev[key][index]
|
|
||||||
if prev is not None:
|
|
||||||
x += prev
|
|
||||||
out[key].insert(0, None)
|
|
||||||
out[key].insert(0, None)
|
|
||||||
out[key].insert(0, x)
|
|
||||||
|
|
||||||
if control_prev is not None and 'input' in control_prev:
|
|
||||||
for i in range(len(out['input'])):
|
|
||||||
if out['input'][i] is None:
|
|
||||||
out['input'][i] = control_prev['input'][i]
|
|
||||||
if control_prev is not None and 'middle' in control_prev:
|
|
||||||
out['middle'] = control_prev['middle']
|
|
||||||
if control_prev is not None and 'output' in control_prev:
|
|
||||||
out['output'] = control_prev['output']
|
|
||||||
return out
|
|
||||||
|
|
||||||
def copy(self):
|
def copy(self):
|
||||||
c = T2IAdapter(self.t2i_model, self.channels_in)
|
c = T2IAdapter(self.t2i_model, self.channels_in)
|
||||||
@ -1145,11 +1163,20 @@ def load_t2i_adapter(t2i_data):
|
|||||||
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
|
down_opts = list(filter(lambda a: a.endswith("down_opt.op.weight"), keys))
|
||||||
if len(down_opts) > 0:
|
if len(down_opts) > 0:
|
||||||
use_conv = True
|
use_conv = True
|
||||||
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv)
|
xl = False
|
||||||
|
if cin == 256:
|
||||||
|
xl = True
|
||||||
|
model_ad = adapter.Adapter(cin=cin, channels=[channel, channel*2, channel*4, channel*4][:4], nums_rb=2, ksize=ksize, sk=True, use_conv=use_conv, xl=xl)
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
model_ad.load_state_dict(t2i_data)
|
missing, unexpected = model_ad.load_state_dict(t2i_data)
|
||||||
return T2IAdapter(model_ad, cin // 64)
|
if len(missing) > 0:
|
||||||
|
print("t2i missing", missing)
|
||||||
|
|
||||||
|
if len(unexpected) > 0:
|
||||||
|
print("t2i unexpected", unexpected)
|
||||||
|
|
||||||
|
return T2IAdapter(model_ad, model_ad.input_channels)
|
||||||
|
|
||||||
|
|
||||||
class StyleModel:
|
class StyleModel:
|
||||||
|
|||||||
@ -101,17 +101,30 @@ class ResnetBlock(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class Adapter(nn.Module):
|
class Adapter(nn.Module):
|
||||||
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True):
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64, ksize=3, sk=False, use_conv=True, xl=True):
|
||||||
super(Adapter, self).__init__()
|
super(Adapter, self).__init__()
|
||||||
self.unshuffle = nn.PixelUnshuffle(8)
|
self.unshuffle_amount = 8
|
||||||
|
resblock_no_downsample = []
|
||||||
|
resblock_downsample = [3, 2, 1]
|
||||||
|
self.xl = xl
|
||||||
|
if self.xl:
|
||||||
|
self.unshuffle_amount = 16
|
||||||
|
resblock_no_downsample = [1]
|
||||||
|
resblock_downsample = [2]
|
||||||
|
|
||||||
|
self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
|
||||||
|
self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.nums_rb = nums_rb
|
self.nums_rb = nums_rb
|
||||||
self.body = []
|
self.body = []
|
||||||
for i in range(len(channels)):
|
for i in range(len(channels)):
|
||||||
for j in range(nums_rb):
|
for j in range(nums_rb):
|
||||||
if (i != 0) and (j == 0):
|
if (i in resblock_downsample) and (j == 0):
|
||||||
self.body.append(
|
self.body.append(
|
||||||
ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
|
ResnetBlock(channels[i - 1], channels[i], down=True, ksize=ksize, sk=sk, use_conv=use_conv))
|
||||||
|
elif (i in resblock_no_downsample) and (j == 0):
|
||||||
|
self.body.append(
|
||||||
|
ResnetBlock(channels[i - 1], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
||||||
else:
|
else:
|
||||||
self.body.append(
|
self.body.append(
|
||||||
ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
ResnetBlock(channels[i], channels[i], down=False, ksize=ksize, sk=sk, use_conv=use_conv))
|
||||||
@ -128,6 +141,16 @@ class Adapter(nn.Module):
|
|||||||
for j in range(self.nums_rb):
|
for j in range(self.nums_rb):
|
||||||
idx = i * self.nums_rb + j
|
idx = i * self.nums_rb + j
|
||||||
x = self.body[idx](x)
|
x = self.body[idx](x)
|
||||||
|
if self.xl:
|
||||||
|
features.append(None)
|
||||||
|
if i == 0:
|
||||||
|
features.append(None)
|
||||||
|
features.append(None)
|
||||||
|
if i == 2:
|
||||||
|
features.append(None)
|
||||||
|
else:
|
||||||
|
features.append(None)
|
||||||
|
features.append(None)
|
||||||
features.append(x)
|
features.append(x)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
@ -241,10 +264,14 @@ class extractor(nn.Module):
|
|||||||
class Adapter_light(nn.Module):
|
class Adapter_light(nn.Module):
|
||||||
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
|
def __init__(self, channels=[320, 640, 1280, 1280], nums_rb=3, cin=64):
|
||||||
super(Adapter_light, self).__init__()
|
super(Adapter_light, self).__init__()
|
||||||
self.unshuffle = nn.PixelUnshuffle(8)
|
self.unshuffle_amount = 8
|
||||||
|
self.unshuffle = nn.PixelUnshuffle(self.unshuffle_amount)
|
||||||
|
self.input_channels = cin // (self.unshuffle_amount * self.unshuffle_amount)
|
||||||
self.channels = channels
|
self.channels = channels
|
||||||
self.nums_rb = nums_rb
|
self.nums_rb = nums_rb
|
||||||
self.body = []
|
self.body = []
|
||||||
|
self.xl = False
|
||||||
|
|
||||||
for i in range(len(channels)):
|
for i in range(len(channels)):
|
||||||
if i == 0:
|
if i == 0:
|
||||||
self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
|
self.body.append(extractor(in_c=cin, inter_c=channels[i]//4, out_c=channels[i], nums_rb=nums_rb, down=False))
|
||||||
@ -259,6 +286,8 @@ class Adapter_light(nn.Module):
|
|||||||
features = []
|
features = []
|
||||||
for i in range(len(self.channels)):
|
for i in range(len(self.channels)):
|
||||||
x = self.body[i](x)
|
x = self.body[i](x)
|
||||||
|
features.append(None)
|
||||||
|
features.append(None)
|
||||||
features.append(x)
|
features.append(x)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|||||||
15
nodes.py
15
nodes.py
@ -1306,7 +1306,7 @@ class LoadImage:
|
|||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||||
return {"required":
|
return {"required":
|
||||||
{"image": (sorted(files), )},
|
{"image": (sorted(files), {"image_upload": True})},
|
||||||
}
|
}
|
||||||
|
|
||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
@ -1349,7 +1349,7 @@ class LoadImageMask:
|
|||||||
input_dir = folder_paths.get_input_directory()
|
input_dir = folder_paths.get_input_directory()
|
||||||
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
|
||||||
return {"required":
|
return {"required":
|
||||||
{"image": (sorted(files), ),
|
{"image": (sorted(files), {"image_upload": True}),
|
||||||
"channel": (s._color_channels, ), }
|
"channel": (s._color_channels, ), }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1673,6 +1673,8 @@ NODE_DISPLAY_NAME_MAPPINGS = {
|
|||||||
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
"VAEEncodeTiled": "VAE Encode (Tiled)",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
EXTENSION_WEB_DIRS = {}
|
||||||
|
|
||||||
def load_custom_node(module_path, ignore=set()):
|
def load_custom_node(module_path, ignore=set()):
|
||||||
module_name = os.path.basename(module_path)
|
module_name = os.path.basename(module_path)
|
||||||
if os.path.isfile(module_path):
|
if os.path.isfile(module_path):
|
||||||
@ -1681,11 +1683,20 @@ def load_custom_node(module_path, ignore=set()):
|
|||||||
try:
|
try:
|
||||||
if os.path.isfile(module_path):
|
if os.path.isfile(module_path):
|
||||||
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
|
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
|
||||||
|
module_dir = os.path.split(module_path)[0]
|
||||||
else:
|
else:
|
||||||
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
|
module_spec = importlib.util.spec_from_file_location(module_name, os.path.join(module_path, "__init__.py"))
|
||||||
|
module_dir = module_path
|
||||||
|
|
||||||
module = importlib.util.module_from_spec(module_spec)
|
module = importlib.util.module_from_spec(module_spec)
|
||||||
sys.modules[module_name] = module
|
sys.modules[module_name] = module
|
||||||
module_spec.loader.exec_module(module)
|
module_spec.loader.exec_module(module)
|
||||||
|
|
||||||
|
if hasattr(module, "WEB_DIRECTORY") and getattr(module, "WEB_DIRECTORY") is not None:
|
||||||
|
web_dir = os.path.abspath(os.path.join(module_dir, getattr(module, "WEB_DIRECTORY")))
|
||||||
|
if os.path.isdir(web_dir):
|
||||||
|
EXTENSION_WEB_DIRS[module_name] = web_dir
|
||||||
|
|
||||||
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
if hasattr(module, "NODE_CLASS_MAPPINGS") and getattr(module, "NODE_CLASS_MAPPINGS") is not None:
|
||||||
for name in module.NODE_CLASS_MAPPINGS:
|
for name in module.NODE_CLASS_MAPPINGS:
|
||||||
if name not in ignore:
|
if name not in ignore:
|
||||||
|
|||||||
@ -75,6 +75,8 @@
|
|||||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
|
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/resolve/main/sd_xl_base_1.0.safetensors -P ./models/checkpoints/\n",
|
||||||
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
|
"#!wget -c https://huggingface.co/stabilityai/stable-diffusion-xl-refiner-1.0/resolve/main/sd_xl_refiner_1.0.safetensors -P ./models/checkpoints/\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# SDXL ReVision\n",
|
||||||
|
"#!wget -c https://huggingface.co/comfyanonymous/clip_vision_g/resolve/main/clip_vision_g.safetensors -P ./models/clip_vision/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# SD1.5\n",
|
"# SD1.5\n",
|
||||||
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
|
"!wget -c https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned-emaonly.ckpt -P ./models/checkpoints/\n",
|
||||||
@ -142,6 +144,11 @@
|
|||||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n",
|
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11p_sd15s2_lineart_anime_fp16.safetensors -P ./models/controlnet/\n",
|
||||||
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n",
|
"#!wget -c https://huggingface.co/comfyanonymous/ControlNet-v1-1_fp16_safetensors/resolve/main/control_v11u_sd15_tile_fp16.safetensors -P ./models/controlnet/\n",
|
||||||
"\n",
|
"\n",
|
||||||
|
"# ControlNet SDXL\n",
|
||||||
|
"#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-canny-rank256.safetensors -P ./models/controlnet/\n",
|
||||||
|
"#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-depth-rank256.safetensors -P ./models/controlnet/\n",
|
||||||
|
"#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-recolor-rank256.safetensors -P ./models/controlnet/\n",
|
||||||
|
"#!wget -c https://huggingface.co/stabilityai/control-lora/resolve/main/control-LoRAs-rank256/control-lora-sketch-rank256.safetensors -P ./models/controlnet/\n",
|
||||||
"\n",
|
"\n",
|
||||||
"# Controlnet Preprocessor nodes by Fannovel16\n",
|
"# Controlnet Preprocessor nodes by Fannovel16\n",
|
||||||
"#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n",
|
"#!cd custom_nodes && git clone https://github.com/Fannovel16/comfy_controlnet_preprocessors; cd comfy_controlnet_preprocessors && python install.py\n",
|
||||||
|
|||||||
22
server.py
22
server.py
@ -5,6 +5,7 @@ import nodes
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
import execution
|
import execution
|
||||||
import uuid
|
import uuid
|
||||||
|
import urllib
|
||||||
import json
|
import json
|
||||||
import glob
|
import glob
|
||||||
import struct
|
import struct
|
||||||
@ -67,6 +68,8 @@ class PromptServer():
|
|||||||
|
|
||||||
mimetypes.init()
|
mimetypes.init()
|
||||||
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
mimetypes.types_map['.js'] = 'application/javascript; charset=utf-8'
|
||||||
|
|
||||||
|
self.supports = ["custom_nodes_from_web"]
|
||||||
self.prompt_queue = None
|
self.prompt_queue = None
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.messages = asyncio.Queue()
|
self.messages = asyncio.Queue()
|
||||||
@ -123,8 +126,17 @@ class PromptServer():
|
|||||||
|
|
||||||
@routes.get("/extensions")
|
@routes.get("/extensions")
|
||||||
async def get_extensions(request):
|
async def get_extensions(request):
|
||||||
files = glob.glob(os.path.join(self.web_root, 'extensions/**/*.js'), recursive=True)
|
files = glob.glob(os.path.join(
|
||||||
return web.json_response(list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files)))
|
self.web_root, 'extensions/**/*.js'), recursive=True)
|
||||||
|
|
||||||
|
extensions = list(map(lambda f: "/" + os.path.relpath(f, self.web_root).replace("\\", "/"), files))
|
||||||
|
|
||||||
|
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||||
|
files = glob.glob(os.path.join(dir, '**/*.js'), recursive=True)
|
||||||
|
extensions.extend(list(map(lambda f: "/extensions/" + urllib.parse.quote(
|
||||||
|
name) + "/" + os.path.relpath(f, dir).replace("\\", "/"), files)))
|
||||||
|
|
||||||
|
return web.json_response(extensions)
|
||||||
|
|
||||||
def get_dir_by_type(dir_type):
|
def get_dir_by_type(dir_type):
|
||||||
if dir_type is None:
|
if dir_type is None:
|
||||||
@ -492,6 +504,12 @@ class PromptServer():
|
|||||||
|
|
||||||
def add_routes(self):
|
def add_routes(self):
|
||||||
self.app.add_routes(self.routes)
|
self.app.add_routes(self.routes)
|
||||||
|
|
||||||
|
for name, dir in nodes.EXTENSION_WEB_DIRS.items():
|
||||||
|
self.app.add_routes([
|
||||||
|
web.static('/extensions/' + urllib.parse.quote(name), dir, follow_symlinks=True),
|
||||||
|
])
|
||||||
|
|
||||||
self.app.add_routes([
|
self.app.add_routes([
|
||||||
web.static('/', self.web_root, follow_symlinks=True),
|
web.static('/', self.web_root, follow_symlinks=True),
|
||||||
])
|
])
|
||||||
|
|||||||
@ -5,7 +5,8 @@ import { app } from "../../scripts/app.js";
|
|||||||
app.registerExtension({
|
app.registerExtension({
|
||||||
name: "Comfy.UploadImage",
|
name: "Comfy.UploadImage",
|
||||||
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
async beforeRegisterNodeDef(nodeType, nodeData, app) {
|
||||||
if (nodeData.name === "LoadImage" || nodeData.name === "LoadImageMask") {
|
console.log(nodeData);
|
||||||
|
if (nodeData?.input?.required?.image?.[1]?.image_upload === true) {
|
||||||
nodeData.input.required.upload = ["IMAGEUPLOAD"];
|
nodeData.input.required.upload = ["IMAGEUPLOAD"];
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
@ -1026,18 +1026,21 @@ export class ComfyApp {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads all extensions from the API into the window
|
* Loads all extensions from the API into the window in parallel
|
||||||
*/
|
*/
|
||||||
async #loadExtensions() {
|
async #loadExtensions() {
|
||||||
const extensions = await api.getExtensions();
|
const extensions = await api.getExtensions();
|
||||||
this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions });
|
this.logging.addEntry("Comfy.App", "debug", { Extensions: extensions });
|
||||||
for (const ext of extensions) {
|
|
||||||
try {
|
const extensionPromises = extensions.map(async ext => {
|
||||||
await import(api.apiURL(ext));
|
try {
|
||||||
} catch (error) {
|
await import(api.apiURL(ext));
|
||||||
console.error("Error loading extension", ext, error);
|
} catch (error) {
|
||||||
}
|
console.error("Error loading extension", ext, error);
|
||||||
}
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
await Promise.all(extensionPromises);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user