Merge upstream

doctorpangloss 2024-02-29 20:48:27 -08:00
commit 915f2da874
6 changed files with 82 additions and 43 deletions

comfy/diffusers_convert.py (View File)

@@ -240,9 +240,9 @@ def convert_text_enc_state_dict_v20(text_enc_dict, prefix=""):
         text_proj = "transformer.text_projection.weight"
         if k.endswith(text_proj):
             new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
-
-        relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
-        new_state_dict[relabelled_key] = v
+        else:
+            relabelled_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], k)
+            new_state_dict[relabelled_key] = v
 
     for k_pre, tensors in capture_qkv_weight.items():
         if None in tensors:
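Note on the hunk above: previously a key ending in transformer.text_projection.weight was written twice, once transposed under the text_projection name and then again through the generic relabelling; the new else branch makes the two paths exclusive. A minimal runnable sketch of the fixed behavior, with dummy tensors standing in for a real text-encoder state dict (the regex relabelling is elided):

import torch

# Hedged sketch: names mirror the hunk above, shapes are illustrative.
text_enc_dict = {
    "transformer.text_projection.weight": torch.randn(768, 768),
    "transformer.resblocks.0.attn.out_proj.weight": torch.randn(768, 768),
}
new_state_dict = {}
text_proj = "transformer.text_projection.weight"
for k, v in text_enc_dict.items():
    if k.endswith(text_proj):
        # the projection weight is stored transposed in the converted checkpoint
        new_state_dict[k.replace(text_proj, "text_projection")] = v.transpose(0, 1).contiguous()
    else:
        # every other key only gets relabelled (textenc_pattern.sub in the real code)
        new_state_dict[k] = v

assert set(new_state_dict) == {"text_projection", "transformer.resblocks.0.attn.out_proj.weight"}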

comfy/ldm/modules/diffusionmodules/openaimodel.py (View File)

@@ -484,7 +484,6 @@ class UNetModel(nn.Module):
         self.predict_codebook_ids = n_embed is not None
 
         self.default_num_video_frames = None
-        self.default_image_only_indicator = None
 
         time_embed_dim = model_channels * 4
         self.time_embed = nn.Sequential(
@@ -708,27 +707,30 @@ class UNetModel(nn.Module):
                         device=device,
                         operations=operations
                     )]
-        if transformer_depth_middle >= 0:
-            mid_block += [get_attention_layer(  # always uses a self-attn
-                            ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
-                            disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
-                        ),
-            get_resblock(
-                        merge_factor=merge_factor,
-                        merge_strategy=merge_strategy,
-                        video_kernel_size=video_kernel_size,
-                        ch=ch,
-                        time_embed_dim=time_embed_dim,
-                        dropout=dropout,
-                        out_channels=None,
-                        dims=dims,
-                        use_checkpoint=use_checkpoint,
-                        use_scale_shift_norm=use_scale_shift_norm,
-                        dtype=self.dtype,
-                        device=device,
-                        operations=operations
-                    )]
-        self.middle_block = TimestepEmbedSequential(*mid_block)
+        self.middle_block = None
+        if transformer_depth_middle >= -1:
+            if transformer_depth_middle >= 0:
+                mid_block += [get_attention_layer(  # always uses a self-attn
+                                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
+                                disable_self_attn=disable_middle_self_attn, use_checkpoint=use_checkpoint
+                            ),
+                get_resblock(
+                            merge_factor=merge_factor,
+                            merge_strategy=merge_strategy,
+                            video_kernel_size=video_kernel_size,
+                            ch=ch,
+                            time_embed_dim=time_embed_dim,
+                            dropout=dropout,
+                            out_channels=None,
+                            dims=dims,
+                            use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            dtype=self.dtype,
+                            device=device,
+                            operations=operations
+                        )]
+            self.middle_block = TimestepEmbedSequential(*mid_block)
         self._feature_size += ch
 
         self.output_blocks = nn.ModuleList([])
@@ -827,7 +829,7 @@ class UNetModel(nn.Module):
         transformer_patches = transformer_options.get("patches", {})
 
         num_video_frames = kwargs.get("num_video_frames", self.default_num_video_frames)
-        image_only_indicator = kwargs.get("image_only_indicator", self.default_image_only_indicator)
+        image_only_indicator = kwargs.get("image_only_indicator", None)
         time_context = kwargs.get("time_context", None)
 
         assert (y is not None) == (
@@ -858,7 +860,8 @@ class UNetModel(nn.Module):
                 h = p(h, transformer_options)
 
         transformer_options["block"] = ("middle", 0)
-        h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
+        if self.middle_block is not None:
+            h = forward_timestep_embed(self.middle_block, h, emb, context, transformer_options, time_context=time_context, num_video_frames=num_video_frames, image_only_indicator=image_only_indicator)
         h = apply_control(h, control, 'middle')
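Taken together, the hunks above define a three-way contract for transformer_depth_middle: >= 0 builds the full resblock/attention/resblock middle, -1 builds a resblock-only middle, and anything below -1 (the -2 case the KOALA configs below rely on) leaves self.middle_block as None, which is why the forward pass now guards the forward_timestep_embed call. A schematic sketch of that split, with strings standing in for the real modules:

# Hedged sketch: "resblock"/"attention" stand in for get_resblock/get_attention_layer,
# and the returned list stands in for TimestepEmbedSequential(*mid_block).
def build_middle_block(transformer_depth_middle):
    if transformer_depth_middle < -1:
        return None                             # -2 and below: no middle block at all
    mid_block = ["resblock"]                    # the first resblock is always built
    if transformer_depth_middle >= 0:
        mid_block += ["attention", "resblock"]  # full middle with self-attention
    return mid_block

assert build_middle_block(2) == ["resblock", "attention", "resblock"]
assert build_middle_block(-1) == ["resblock"]
assert build_middle_block(-2) is None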

comfy/ldm/modules/diffusionmodules/util.py (View File)

@@ -46,23 +46,25 @@ class AlphaBlender(nn.Module):
         else:
             raise ValueError(f"unknown merge strategy {self.merge_strategy}")
 
-    def get_alpha(self, image_only_indicator: torch.Tensor) -> torch.Tensor:
+    def get_alpha(self, image_only_indicator: torch.Tensor, device) -> torch.Tensor:
         # skip_time_mix = rearrange(repeat(skip_time_mix, 'b -> (b t) () () ()', t=t), '(b t) 1 ... -> b 1 t ...', t=t)
         if self.merge_strategy == "fixed":
             # make shape compatible
             # alpha = repeat(self.mix_factor, '1 -> b () t () ()', t=t, b=bs)
-            alpha = self.mix_factor.to(image_only_indicator.device)
+            alpha = self.mix_factor.to(device)
         elif self.merge_strategy == "learned":
-            alpha = torch.sigmoid(self.mix_factor.to(image_only_indicator.device))
+            alpha = torch.sigmoid(self.mix_factor.to(device))
             # make shape compatible
             # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
         elif self.merge_strategy == "learned_with_images":
-            assert image_only_indicator is not None, "need image_only_indicator ..."
-            alpha = torch.where(
-                image_only_indicator.bool(),
-                torch.ones(1, 1, device=image_only_indicator.device),
-                rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
-            )
+            if image_only_indicator is None:
+                alpha = rearrange(torch.sigmoid(self.mix_factor.to(device)), "... -> ... 1")
+            else:
+                alpha = torch.where(
+                    image_only_indicator.bool(),
+                    torch.ones(1, 1, device=image_only_indicator.device),
+                    rearrange(torch.sigmoid(self.mix_factor.to(image_only_indicator.device)), "... -> ... 1"),
+                )
             alpha = rearrange(alpha, self.rearrange_pattern)
             # make shape compatible
             # alpha = repeat(alpha, '1 -> s () ()', s = t * bs)
@@ -76,7 +78,7 @@ class AlphaBlender(nn.Module):
         x_temporal,
         image_only_indicator=None,
     ) -> torch.Tensor:
-        alpha = self.get_alpha(image_only_indicator)
+        alpha = self.get_alpha(image_only_indicator, x_spatial.device)
         x = (
             alpha.to(x_spatial.dtype) * x_spatial
             + (1.0 - alpha).to(x_spatial.dtype) * x_temporal
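With the explicit device argument, get_alpha no longer has to derive a device from image_only_indicator, so callers can pass image_only_indicator=None under every merge strategy; this is also what lets SVD_img2vid in the next file stop injecting a zero-filled indicator through the conditioning. A hedged usage sketch (import path and constructor arguments are assumed from the surrounding ComfyUI sources, not confirmed by this diff):

import torch
from comfy.ldm.modules.diffusionmodules.util import AlphaBlender  # path assumed

# "learned_with_images" used to assert image_only_indicator is not None;
# after this change a None indicator falls back to the learned sigmoid alpha.
blender = AlphaBlender(alpha=0.0, merge_strategy="learned_with_images",
                       rearrange_pattern="b t -> (b t) 1 1")
x_spatial = torch.randn(4, 16, 320)
x_temporal = torch.randn(4, 16, 320)
merged = blender(x_spatial, x_temporal, image_only_indicator=None)  # no device juggling needed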

comfy/model_base.py (View File)

@@ -371,7 +371,6 @@ class SVD_img2vid(BaseModel):
         if "time_conditioning" in kwargs:
             out["time_context"] = conds.CONDCrossAttn(kwargs["time_conditioning"])
 
-        out['image_only_indicator'] = conds.CONDConstant(torch.zeros((1,), device=device))
         out['num_video_frames'] = conds.CONDConstant(noise.shape[0])
         return out

comfy/model_detection.py (View File)

@@ -151,8 +151,10 @@ def detect_unet_config(state_dict, key_prefix):
             channel_mult.append(last_channel_mult)
 
     if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
         transformer_depth_middle = count_blocks(state_dict_keys, '{}middle_block.1.transformer_blocks.'.format(key_prefix) + '{}')
-    else:
+    elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
         transformer_depth_middle = -1
+    else:
+        transformer_depth_middle = -2
 
     unet_config["in_channels"] = in_channels
     unet_config["out_channels"] = out_channels
@@ -242,6 +244,7 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
     down_blocks = count_blocks(state_dict, "down_blocks.{}")
     for i in range(down_blocks):
         attn_blocks = count_blocks(state_dict, "down_blocks.{}.attentions.".format(i) + '{}')
+        res_blocks = count_blocks(state_dict, "down_blocks.{}.resnets.".format(i) + '{}')
         for ab in range(attn_blocks):
             transformer_count = count_blocks(state_dict, "down_blocks.{}.attentions.{}.transformer_blocks.".format(i, ab) + '{}')
             transformer_depth.append(transformer_count)
@@ -250,8 +253,8 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
         attn_res *= 2
 
         if attn_blocks == 0:
-            transformer_depth.append(0)
-            transformer_depth.append(0)
+            for i in range(res_blocks):
+                transformer_depth.append(0)
 
     match["transformer_depth"] = transformer_depth
@@ -329,7 +332,19 @@ def unet_config_from_diffusers_unet(state_dict, dtype=None):
         'channel_mult': [1, 2, 4], 'transformer_depth_middle': -1, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
         'use_temporal_attention': False, 'use_temporal_resblock': False}
 
-    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega]
+    KOALA_700M = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+        'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
+        'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 5], 'transformer_depth_output': [0, 0, 2, 2, 5, 5],
+        'channel_mult': [1, 2, 4], 'transformer_depth_middle': -2, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
+        'use_temporal_attention': False, 'use_temporal_resblock': False}
+
+    KOALA_1B = {'use_checkpoint': False, 'image_size': 32, 'out_channels': 4, 'use_spatial_transformer': True, 'legacy': False,
+        'num_classes': 'sequential', 'adm_in_channels': 2816, 'dtype': dtype, 'in_channels': 4, 'model_channels': 320,
+        'num_res_blocks': [1, 1, 1], 'transformer_depth': [0, 2, 6], 'transformer_depth_output': [0, 0, 2, 2, 6, 6],
+        'channel_mult': [1, 2, 4], 'transformer_depth_middle': 6, 'use_linear_in_transformer': True, 'context_dim': 2048, 'num_head_channels': 64,
+        'use_temporal_attention': False, 'use_temporal_resblock': False}
+
+    supported_models = [SDXL, SDXL_refiner, SD21, SD15, SD21_uncliph, SD21_unclipl, SDXL_mid_cnet, SDXL_small_cnet, SDXL_diffusers_inpaint, SSD_1B, Segmind_Vega, KOALA_700M, KOALA_1B]
 
     for unet_config in supported_models:
         matches = True
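The detection changes mirror the UNet change: the middle-block layout is now inferred three ways from checkpoint keys, and the diffusers-format matcher pads transformer_depth once per resnet instead of assuming a fixed count. A hedged sketch of the three-way probe on plain key sets (the real code counts transformer_blocks where this sketch returns a placeholder depth):

# Hedged sketch of the key probing in detect_unet_config above.
def middle_depth_from_keys(state_dict_keys, key_prefix=""):
    if "{}middle_block.1.proj_in.weight".format(key_prefix) in state_dict_keys:
        return 1    # attention present; the real code counts transformer_blocks here
    elif "{}middle_block.0.in_layers.0.weight".format(key_prefix) in state_dict_keys:
        return -1   # resblock-only middle block
    else:
        return -2   # no middle block, as in the KOALA checkpoints

assert middle_depth_from_keys({"middle_block.1.proj_in.weight"}) == 1
assert middle_depth_from_keys({"middle_block.0.in_layers.0.weight"}) == -1
assert middle_depth_from_keys(set()) == -2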

comfy/supported_models.py (View File)

@@ -234,6 +234,26 @@ class Segmind_Vega(SDXL):
         "use_temporal_attention": False,
     }
 
+class KOALA_700M(SDXL):
+    unet_config = {
+        "model_channels": 320,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [0, 2, 5],
+        "context_dim": 2048,
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
+    }
+
+class KOALA_1B(SDXL):
+    unet_config = {
+        "model_channels": 320,
+        "use_linear_in_transformer": True,
+        "transformer_depth": [0, 2, 6],
+        "context_dim": 2048,
+        "adm_in_channels": 2816,
+        "use_temporal_attention": False,
+    }
+
 class SVD_img2vid(supported_models_base.BASE):
     unet_config = {
         "model_channels": 320,
@@ -380,5 +400,5 @@ class Stable_Cascade_B(Stable_Cascade_C):
         return out
 
-models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
+models = [Stable_Zero123, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B]
 
 models += [SVD_img2vid]
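For reference, KOALA_700M and KOALA_1B only override unet_config relative to SDXL; latent format, conditioning, and sampling behavior are inherited. Matching a checkpoint against these classes reduces to comparing detected UNet settings with each config dict, so transformer_depth [0, 2, 5] versus [0, 2, 6] is what separates the two variants. A hedged sketch of that kind of dict match (illustrative only, not the repository's matcher):

# Hedged sketch: every key declared by the model class must equal the
# value detected from the checkpoint's state dict.
def config_matches(unet_config, detected):
    return all(detected.get(k) == v for k, v in unet_config.items())

koala_700m = {"model_channels": 320, "transformer_depth": [0, 2, 5]}
koala_1b = {"model_channels": 320, "transformer_depth": [0, 2, 6]}
detected = {"model_channels": 320, "transformer_depth": [0, 2, 6], "context_dim": 2048}
assert not config_matches(koala_700m, detected)
assert config_matches(koala_1b, detected)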