Merge branch 'comfyanonymous:master' into feature/blockweights

This commit is contained in:
Dr.Lt.Data 2023-05-06 09:43:41 +09:00 committed by GitHub
commit 8b39a6d49b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 45 additions and 16 deletions

View File

@ -242,14 +242,28 @@ class Gligen(nn.Module):
self.position_net = position_net
self.key_dim = key_dim
self.max_objs = 30
self.lowvram = False
def _set_position(self, boxes, masks, positive_embeddings):
if self.lowvram == True:
self.position_net.to(boxes.device)
objs = self.position_net(boxes, masks, positive_embeddings)
def func(key, x):
module = self.module_list[key]
return module(x, objs)
return func
if self.lowvram == True:
self.position_net.cpu()
def func_lowvram(key, x):
module = self.module_list[key]
module.to(x.device)
r = module(x, objs)
module.cpu()
return r
return func_lowvram
else:
def func(key, x):
module = self.module_list[key]
return module(x, objs)
return func
def set_position(self, latent_image_shape, position_params, device):
batch, c, h, w = latent_image_shape
@ -294,8 +308,11 @@ class Gligen(nn.Module):
masks.to(device),
conds.to(device))
def set_lowvram(self, value=True):
self.lowvram = value
def cleanup(self):
pass
self.lowvram = False
def get_models(self):
return [self]

View File

@ -572,9 +572,6 @@ class BasicTransformerBlock(nn.Module):
x += n
x = self.ff(self.norm3(x)) + x
if current_index is not None:
transformer_options["current_index"] += 1
return x

View File

@ -88,6 +88,19 @@ class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
x = layer(x)
return x
#This is needed because accelerate makes a copy of transformer_options which breaks "current_index"
def forward_timestep_embed(ts, x, emb, context=None, transformer_options={}, output_shape=None):
for layer in ts:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context, transformer_options)
transformer_options["current_index"] += 1
elif isinstance(layer, Upsample):
x = layer(x, output_shape=output_shape)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
@ -805,13 +818,13 @@ class UNetModel(nn.Module):
h = x.type(self.dtype)
for id, module in enumerate(self.input_blocks):
h = module(h, emb, context, transformer_options)
h = forward_timestep_embed(module, h, emb, context, transformer_options)
if control is not None and 'input' in control and len(control['input']) > 0:
ctrl = control['input'].pop()
if ctrl is not None:
h += ctrl
hs.append(h)
h = self.middle_block(h, emb, context, transformer_options)
h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options)
if control is not None and 'middle' in control and len(control['middle']) > 0:
h += control['middle'].pop()
@ -828,7 +841,7 @@ class UNetModel(nn.Module):
output_shape = hs[-1].shape
else:
output_shape = None
h = module(h, emb, context, transformer_options, output_shape)
h = forward_timestep_embed(module, h, emb, context, transformer_options, output_shape)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)

View File

@ -201,6 +201,9 @@ def load_controlnet_gpu(control_models):
return
if vram_state == VRAMState.LOW_VRAM or vram_state == VRAMState.NO_VRAM:
for m in control_models:
if hasattr(m, 'set_lowvram'):
m.set_lowvram(True)
#don't load controlnets like this if low vram because they will be loaded right before running and unloaded right after
return

View File

@ -724,7 +724,7 @@ export class ComfyApp {
ctx.globalAlpha = 0.8;
ctx.beginPath();
if (shape == LiteGraph.BOX_SHAPE)
ctx.rect(-6, -6 + LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
ctx.rect(-6, -6 - LiteGraph.NODE_TITLE_HEIGHT, 12 + size[0] + 1, 12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT);
else if (shape == LiteGraph.ROUND_SHAPE || (shape == LiteGraph.CARD_SHAPE && node.flags.collapsed))
ctx.roundRect(
-6,
@ -736,12 +736,11 @@ export class ComfyApp {
else if (shape == LiteGraph.CARD_SHAPE)
ctx.roundRect(
-6,
-6 + LiteGraph.NODE_TITLE_HEIGHT,
-6 - LiteGraph.NODE_TITLE_HEIGHT,
12 + size[0] + 1,
12 + size[1] + LiteGraph.NODE_TITLE_HEIGHT,
this.round_radius * 2,
2
);
[this.round_radius * 2, this.round_radius * 2, 2, 2]
);
else if (shape == LiteGraph.CIRCLE_SHAPE)
ctx.arc(size[0] * 0.5, size[1] * 0.5, size[0] * 0.5 + 6, 0, Math.PI * 2);
ctx.strokeStyle = color;