mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-10 14:20:49 +08:00)

Commit d79d7a7e08 (parent 666f5b96f7): fix imports and other basic problems
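Most of the changes below follow three recurring patterns: absolute `comfy.*` imports become package-relative imports, ad-hoc `logging.*` calls go through a module-level `logger`, and mutable default arguments such as `model_options={}` or `tokenizer_data={}` become `None` with an explicit check. A minimal sketch of the combined pattern, for orientation only (the `Example` class and module layout are hypothetical and not part of this commit):

    import logging

    from . import model_management  # relative import inside the comfy package

    logger = logging.getLogger(__name__)


    class Example:
        def __init__(self, model_options=None):
            # avoid a shared mutable default dict across instances
            if model_options is None:
                model_options = {}
            self.model_options = model_options
            self.device = model_management.intermediate_device()
            logger.debug("created Example with options %s", model_options)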
@@ -4,6 +4,7 @@ import os
 from aiohttp import web
 import logging
 
+logger = logging.getLogger(__name__)
 
 class AppSettings():
     def __init__(self, user_manager):
@@ -16,14 +17,14 @@ class AppSettings():
                 "comfy.settings.json"
             )
         except KeyError as e:
-            logging.error("User settings not found.")
+            logger.error("User settings not found.")
             raise web.HTTPUnauthorized() from e
         if os.path.isfile(file):
             try:
                 with open(file) as f:
                     return json.load(f)
             except:
-                logging.error(f"The user settings file is corrupted: {file}")
+                logger.error(f"The user settings file is corrupted: {file}")
                 return {}
         else:
             return {}
@@ -22,7 +22,7 @@ default_user = "default"
 class FileInfo(TypedDict):
     path: str
     size: int
-    modified: int
+    modified: float
 
 
 def get_file_info(path: str, relative_to: str) -> FileInfo:

@@ -97,7 +97,7 @@ class CONDList(CONDRegular):
     def process_cond(self, batch_size, device, **kwargs):
         out = []
         for c in self.cond:
-            out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
+            out.append(utils.repeat_to_batch_size(c, batch_size).to(device))
 
         return self._copy_with(out)
 

new file (empty): comfy/ldm/ace/__init__.py
@@ -18,8 +18,8 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 
-import comfy.model_management
-from comfy.ldm.modules.attention import optimized_attention
+from ... import model_management
+from ..modules.attention import optimized_attention
 
 class Attention(nn.Module):
     def __init__(

@@ -704,7 +704,7 @@ class LinearTransformerBlock(nn.Module):
         # step 1: AdaLN single
         if self.use_adaln_single:
             shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-                comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
+                model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
             ).chunk(6, dim=1)
 
         norm_hidden_states = self.norm1(hidden_states)

@@ -4,7 +4,7 @@ import math
 import torch
 from torch import nn
 
-import comfy.model_management
+from ... import model_management
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model."""

@@ -423,9 +423,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
         p = p.transpose(1, 2) # (batch, head, time1, d_k)
 
         # (batch, head, time1, d_k)
-        q_with_bias_u = (q + comfy.model_management.cast_to(self.pos_bias_u, dtype=q.dtype, device=q.device)).transpose(1, 2)
+        q_with_bias_u = (q + model_management.cast_to(self.pos_bias_u, dtype=q.dtype, device=q.device)).transpose(1, 2)
         # (batch, head, time1, d_k)
-        q_with_bias_v = (q + comfy.model_management.cast_to(self.pos_bias_v, dtype=q.dtype, device=q.device)).transpose(1, 2)
+        q_with_bias_v = (q + model_management.cast_to(self.pos_bias_v, dtype=q.dtype, device=q.device)).transpose(1, 2)
 
         # compute attention score
         # first compute matrix a and matrix c

@@ -18,9 +18,9 @@ from typing import Optional, List, Union
 import torch
 from torch import nn
 
-import comfy.model_management
+from ... import model_management
 
-from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
+from ..lightricks.model import TimestepEmbedding, Timesteps
 from .attention import LinearTransformerBlock, t2i_modulate
 from .lyric_encoder import ConformerEncoder as LyricEncoder
 

@@ -104,7 +104,7 @@ class T2IFinalLayer(nn.Module):
         return output
 
     def forward(self, x, t, output_length):
-        shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
+        shift, scale = (model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
         x = t2i_modulate(self.norm_final(x), shift, scale)
         x = self.linear(x)
         # unpatchify

new file (empty): comfy/ldm/ace/vae/__init__.py
@@ -4,9 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from typing import Tuple, Union
 
-import comfy.model_management
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from .... import model_management
+from ....ops import disable_weight_init as ops
 
 
 class RMSNorm(ops.RMSNorm):

@@ -19,7 +18,7 @@ class RMSNorm(ops.RMSNorm):
         x = super().forward(x)
         if self.elementwise_affine:
             if self.bias is not None:
-                x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
+                x = x + model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
         return x
 
 

@@ -8,7 +8,7 @@ try:
 except:
     logging.warning("torchaudio missing, ACE model will be broken")
 
-import comfy.model_management
+from .... import model_management
 
 class LinearSpectrogram(nn.Module):
     def __init__(

@@ -47,7 +47,7 @@ class LinearSpectrogram(nn.Module):
             self.n_fft,
             hop_length=self.hop_length,
             win_length=self.win_length,
-            window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
+            window=model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
             center=self.center,
             pad_mode="reflect",
             normalized=False,

@@ -12,9 +12,8 @@ from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_
 
 from .music_log_mel import LogMelSpectrogram
 
-import comfy.model_management
-import comfy.ops
-ops = comfy.ops.disable_weight_init
+from .... import model_management
+from ....ops import disable_weight_init as ops
 
 
 def drop_path(

@@ -77,13 +76,13 @@ class LayerNorm(nn.Module):
     def forward(self, x):
         if self.data_format == "channels_last":
             return F.layer_norm(
-                x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
+                x, self.normalized_shape, model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
             )
         elif self.data_format == "channels_first":
             u = x.mean(1, keepdim=True)
             s = (x - u).pow(2).mean(1, keepdim=True)
             x = (x - u) / torch.sqrt(s + self.eps)
-            x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
+            x = model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
             return x
 
 

@@ -145,7 +144,7 @@ class ConvNeXtBlock(nn.Module):
         x = self.pwconv2(x)
 
         if self.gamma is not None:
-            x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
+            x = model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
 
         x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
         x = self.drop_path(x)

new file (empty): comfy/ldm/chroma/__init__.py
@@ -1,14 +1,8 @@
 import torch
 from torch import Tensor, nn
 
-from comfy.ldm.flux.math import attention
-from comfy.ldm.flux.layers import (
-    MLPEmbedder,
-    RMSNorm,
-    QKNorm,
-    SelfAttention,
-    ModulationOut,
-)
+from ..flux.math import attention
+from ..flux.layers import MLPEmbedder, RMSNorm, QKNorm, SelfAttention, ModulationOut
 
 
 

@@ -5,12 +5,9 @@ from dataclasses import dataclass
 import torch
 from torch import Tensor, nn
 from einops import rearrange, repeat
-import comfy.ldm.common_dit
+from ..common_dit import pad_to_patch_size
 
-from comfy.ldm.flux.layers import (
-    EmbedND,
-    timestep_embedding,
-)
+from ..flux.layers import EmbedND, timestep_embedding
 
 from .layers import (
     DoubleStreamBlock,

@@ -255,7 +252,7 @@ class Chroma(nn.Module):
     def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
+        x = pad_to_patch_size(x, (patch_size, patch_size))
 
         img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)
 

new file (empty): comfy/ldm/hidream/__init__.py
@@ -5,15 +5,15 @@ import torch.nn as nn
 import einops
 from einops import repeat
 
-from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
+from ..lightricks.model import TimestepEmbedding, Timesteps
 import torch.nn.functional as F
 
-from comfy.ldm.flux.math import apply_rope, rope
-from comfy.ldm.flux.layers import LastLayer
+from ..flux.math import apply_rope, rope
+from ..flux.layers import LastLayer
 
-from comfy.ldm.modules.attention import optimized_attention
-import comfy.model_management
-import comfy.ldm.common_dit
+from ..modules.attention import optimized_attention
+from ...model_management import cast_to
+from ..common_dit import pad_to_patch_size
 
 
 # Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py

@@ -261,7 +261,7 @@ class MoEGate(nn.Module):
 
         ### compute gating score
         hidden_states = hidden_states.view(-1, h)
-        logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
+        logits = F.linear(hidden_states, cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
         if self.scoring_func == 'softmax':
             scores = logits.softmax(dim=-1)
         else:

@@ -706,7 +706,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
         bs, c, h, w = x.shape
         if image_cond is not None:
             x = torch.cat([x, image_cond], dim=-1)
-        hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
+        hidden_states = pad_to_patch_size(x, (self.patch_size, self.patch_size))
         timesteps = t
         pooled_embeds = y
         T5_encoder_hidden_states = context

new file (empty): comfy/ldm/wan/__init__.py
@@ -6,11 +6,11 @@ import torch
 import torch.nn as nn
 from einops import repeat
 
-from comfy.ldm.modules.attention import optimized_attention
-from comfy.ldm.flux.layers import EmbedND
-from comfy.ldm.flux.math import apply_rope
-import comfy.ldm.common_dit
-import comfy.model_management
+from ..modules.attention import optimized_attention
+from ..flux.layers import EmbedND
+from ..flux.math import apply_rope
+from ..common_dit import pad_to_patch_size
+from ...model_management import cast_to
 
 
 def sinusoidal_embedding_1d(dim, position):

@@ -202,7 +202,7 @@ class WanAttentionBlock(nn.Module):
         """
         # assert e.dtype == torch.float32
 
-        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
+        e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
         # assert e[0].dtype == torch.float32
 
         # self-attention

@@ -325,7 +325,7 @@ class Head(nn.Module):
             e(Tensor): Shape [B, C]
         """
         # assert e.dtype == torch.float32
-        e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
+        e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
         x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
         return x
 

@@ -347,7 +347,7 @@ class MLPProj(torch.nn.Module):
 
     def forward(self, image_embeds):
         if self.emb_pos is not None:
-            image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
+            image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
 
         clip_extra_context_tokens = self.proj(image_embeds)
         return clip_extra_context_tokens

@@ -541,7 +541,7 @@ class WanModel(torch.nn.Module):
 
     def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
         bs, c, t, h, w = x.shape
-        x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
+        x = pad_to_patch_size(x, self.patch_size)
 
         patch_size = self.patch_size
         t_len = ((t + (patch_size[0] // 2)) // patch_size[0])

@@ -549,7 +549,7 @@ class WanModel(torch.nn.Module):
         w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
 
         if time_dim_concat is not None:
-            time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
+            time_dim_concat = pad_to_patch_size(time_dim_concat, self.patch_size)
             x = torch.cat([x, time_dim_concat], dim=2)
             t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])
 
@@ -1,5 +1,5 @@
 import torch
-import comfy.utils
+from .utils import state_dict_prefix_replace
 
 
 def convert_lora_bfl_control(sd): #BFL loras for Flux

@@ -13,7 +13,7 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
 
 
 def convert_lora_wan_fun(sd): #Wan Fun loras
-    return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
+    return state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
 
 
 def convert_lora(sd):
@@ -1,13 +1,16 @@
-from comfy import sd1_clip
-from .spiece_tokenizer import SPieceTokenizer
-import comfy.text_encoders.t5
+from importlib.resources import files
+import logging
 import os
 import re
 import torch
-import logging
 
 from tokenizers import Tokenizer
 
 from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
+from .spiece_tokenizer import SPieceTokenizer
+from .t5 import T5
+from .. import sd1_clip
+from ..component_model.files import get_path_as_dict
 
 SUPPORT_LANGUAGES = {
     "en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,

@@ -18,11 +21,16 @@ SUPPORT_LANGUAGES = {
 
 structure_pattern = re.compile(r"\[.*?\]")
 
-DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
+def get_vocab_file() -> str:
+    return str(files(f"{__package__}.ace_lyrics_tokenizer") / "vocab.json")
 
 
 
 
 class VoiceBpeTokenizer:
-    def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
+    def __init__(self, vocab_file=None):
+        vocab_file = vocab_file or get_vocab_file()
         self.tokenizer = None
         if vocab_file is not None:
             self.tokenizer = Tokenizer.from_file(vocab_file)
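The change above replaces a vocab path built from `__file__` with `importlib.resources.files`, so the lyrics tokenizer data is resolved through the import system rather than by assuming the package sits as loose files next to the module. A small, self-contained sketch of the same idea (the package and file names here are placeholders, not the ones used in this repository):

    from importlib.resources import files

    def get_resource_path(package: str, name: str) -> str:
        # files() asks the package's loader where its data lives;
        # str() yields a usable path for a normally installed package.
        return str(files(package) / name)

    # hypothetical usage:
    # vocab_path = get_resource_path("mypkg.tokenizer_data", "vocab.json")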
@@ -92,29 +100,40 @@ class VoiceBpeTokenizer:
 
 
 class UMT5BaseModel(sd1_clip.SDClipModel):
-    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
-        textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
+    def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
+        if model_options is None:
+            model_options = {}
+        textmodel_json_config = get_path_as_dict(textmodel_json_config, "umt5_config_base.json", package=__package__)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
 
 
 class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
         tokenizer = tokenizer_data.get("spiece_model", None)
         super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)
 
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
 
 
 class LyricsTokenizer(sd1_clip.SDTokenizer):
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
-        tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
+        tokenizer = get_vocab_file()
         super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
 
 
 class AceT5Tokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
         self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
 
-    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
         out = {}
         out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
         out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)

@@ -126,9 +145,12 @@ class AceT5Tokenizer:
     def state_dict(self):
         return self.umt5base.state_dict()
 
 
 class AceT5Model(torch.nn.Module):
-    def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
+    def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
         super().__init__()
+        if model_options is None:
+            model_options = {}
         self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
         self.dtypes = set()
         if dtype is not None:
@@ -1,6 +1,6 @@
 from transformers import T5TokenizerFast
 
-from comfy import sd1_clip
+from .. import sd1_clip
 from . import sd3_clip
 from ..component_model import files
 
@@ -1,20 +1,24 @@
+import logging
+import torch
+
 from . import hunyuan_video
 from . import sd3_clip
-from comfy import sd1_clip
-from comfy import sdxl_clip
-import comfy.model_management
-import torch
-import logging
+from .. import sd1_clip
+from .. import sdxl_clip
+from ..model_management import intermediate_device, pick_weight_dtype
 
+logger = logging.getLogger(__name__)
 
 class HiDreamTokenizer:
-    def __init__(self, embedding_directory=None, tokenizer_data={}):
+    def __init__(self, embedding_directory=None, tokenizer_data=None):
+        if tokenizer_data is None:
+            tokenizer_data = {}
         self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
         self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
 
-    def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
+    def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
         out = {}
         out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
         out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)

@@ -31,8 +35,10 @@ class HiDreamTokenizer:
 
 
 class HiDreamTEModel(torch.nn.Module):
-    def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
+    def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options=None):
         super().__init__()
+        if model_options is None:
+            model_options = {}
         self.dtypes = set()
         if clip_l:
             self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)

@@ -47,14 +53,14 @@ class HiDreamTEModel(torch.nn.Module):
             self.clip_g = None
 
         if t5:
-            dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
+            dtype_t5 = pick_weight_dtype(dtype_t5, dtype, device)
             self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
             self.dtypes.add(dtype_t5)
         else:
            self.t5xxl = None
 
         if llama:
-            dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
+            dtype_llama = pick_weight_dtype(dtype_llama, dtype, device)
             if "vocab_size" not in model_options:
                 model_options["vocab_size"] = 128256
             self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})

@@ -62,7 +68,7 @@ class HiDreamTEModel(torch.nn.Module):
         else:
             self.llama = None
 
-        logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
+        logger.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
 
     def set_clip_options(self, options):
         if self.clip_l is not None:

@@ -97,12 +103,12 @@ class HiDreamTEModel(torch.nn.Module):
         if self.clip_l is not None:
             lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
         else:
-            l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
+            l_pooled = torch.zeros((1, 768), device=intermediate_device())
 
         if self.clip_g is not None:
             g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
         else:
-            g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
+            g_pooled = torch.zeros((1, 1280), device=intermediate_device())
 
         pooled = torch.cat((l_pooled, g_pooled), dim=-1)
 

@@ -120,13 +126,13 @@ class HiDreamTEModel(torch.nn.Module):
             ll_out = None
 
         if t5_out is None:
-            t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())
+            t5_out = torch.zeros((1, 128, 4096), device=intermediate_device())
 
         if ll_out is None:
-            ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
+            ll_out = torch.zeros((1, 32, 1, 4096), device=intermediate_device())
 
         if pooled is None:
-            pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
+            pooled = torch.zeros((1, 768 + 1280), device=intermediate_device())
 
         extra["conditioning_llama3"] = ll_out
         return t5_out, pooled, extra

@@ -144,7 +150,9 @@ class HiDreamTEModel(torch.nn.Module):
 
 def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
     class HiDreamTEModel_(HiDreamTEModel):
-        def __init__(self, device="cpu", dtype=None, model_options={}):
+        def __init__(self, device="cpu", dtype=None, model_options=None):
+            if model_options is None:
+                model_options = {}
             if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
                 model_options = model_options.copy()
                 model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8

@@ -152,4 +160,5 @@ def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, d
                 model_options = model_options.copy()
                 model_options["llama_scaled_fp8"] = llama_scaled_fp8
             super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
+
     return HiDreamTEModel_
@@ -16,7 +16,7 @@ class HyditBertModel(sd1_clip.SDClipModel):
             model_options = dict()
         textmodel_json_config = get_path_as_dict(textmodel_json_config, "hydit_clip.json", package=__package__)
         model_options = {**model_options, "model_name": "hydit_clip"}
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
 
 class HyditBertTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, **kwargs):

@@ -31,7 +31,7 @@ class MT5XLModel(sd1_clip.SDClipModel):
             model_options = dict()
         textmodel_json_config = get_path_as_dict(textmodel_json_config, "mt5_config_xl.json", package=__package__)
         model_options = {**model_options, "model_name": "mt5xl"}
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, return_attention_masks=True)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
 
 class MT5XLTokenizer(sd1_clip.SDTokenizer):
     def __init__(self, tokenizer_data=None, **kwargs):
@@ -1,8 +1,6 @@
-import os
-
-import comfy.text_encoders.t5
-from comfy import sd1_clip
 from .spiece_tokenizer import SPieceTokenizer
+from .t5 import T5
+from .. import sd1_clip
 from ..component_model.files import get_path_as_dict
 
 

@@ -11,7 +9,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
         if model_options is None:
             model_options = {}
         textmodel_json_config = get_path_as_dict(textmodel_json_config, "umt5_config_xxl.json", package=__package__)
-        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
+        super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
 
 
 class UMT5XXlTokenizer(sd1_clip.SDTokenizer):

@@ -19,7 +17,7 @@ class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
         if tokenizer_data is None:
             tokenizer_data = {}
         tokenizer = tokenizer_data.get("spiece_model", None)
-        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)
+        super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data, embedding_directory=embedding_directory)
 
     def state_dict(self):
         return {"spiece_model": self.tokenizer.serialize_model()}
@@ -3,7 +3,7 @@ from typing import Optional
 import torch
 import torch.nn as nn
 
-import comfy.model_management
+from ..model_management import cast_to_device
 
 
 class WeightAdapterBase:

@@ -40,7 +40,7 @@ class WeightAdapterTrainBase(nn.Module):
 
 
 def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
-    dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
+    dora_scale = cast_to_device(dora_scale, weight.device, intermediate_dtype)
     lora_diff *= alpha
     weight_calc = weight + function(lora_diff).type(weight.dtype)
 
@@ -2,9 +2,10 @@ import logging
 from typing import Optional
 
 import torch
-import comfy.model_management
+from ..model_management import cast_to_device
 from .base import WeightAdapterBase, weight_decompose
 
+logger = logging.getLogger(__name__)
 
 class BOFTAdapter(WeightAdapterBase):
     name = "boft"

@@ -62,9 +63,9 @@ class BOFTAdapter(WeightAdapterBase):
         alpha = v[2]
         dora_scale = v[3]
 
-        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
+        blocks = cast_to_device(blocks, weight.device, intermediate_dtype)
         if rescale is not None:
-            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
+            rescale = cast_to_device(rescale, weight.device, intermediate_dtype)
 
         boft_m, block_num, boft_b, *_ = blocks.shape
 

@@ -105,11 +106,11 @@ class BOFTAdapter(WeightAdapterBase):
                 inp = inp * rescale
 
             lora_diff = inp - org
-            lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
+            lora_diff = cast_to_device(lora_diff, weight.device, intermediate_dtype)
             if dora_scale is not None:
                 weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
             else:
                 weight += function((strength * lora_diff).type(weight.dtype))
         except Exception as e:
-            logging.error("ERROR {} {} {}".format(self.name, key, e))
+            logger.error("ERROR {} {} {}".format(self.name, key, e))
         return weight
@@ -2,9 +2,10 @@ import logging
 from typing import Optional
 
 import torch
-import comfy.model_management
+from ..model_management import cast_to_device
 from .base import WeightAdapterBase, weight_decompose
 
+logger = logging.getLogger(__name__)
 
 class GLoRAAdapter(WeightAdapterBase):
     name = "glora"

@@ -64,10 +65,10 @@ class GLoRAAdapter(WeightAdapterBase):
                 old_glora = False
                 rank = v[1].shape[0]
 
-        a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
-        a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
-        b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
-        b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
+        a1 = cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
+        a2 = cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
+        b1 = cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
+        b2 = cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
 
         if v[4] is not None:
             alpha = v[4] / rank

@@ -89,5 +90,5 @@ class GLoRAAdapter(WeightAdapterBase):
             else:
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
         except Exception as e:
-            logging.error("ERROR {} {} {}".format(self.name, key, e))
+            logger.error("ERROR {} {} {}".format(self.name, key, e))
         return weight
@@ -2,9 +2,10 @@ import logging
 from typing import Optional
 
 import torch
-import comfy.model_management
+from ..model_management import cast_to_device
 from .base import WeightAdapterBase, weight_decompose
 
+logger = logging.getLogger(__name__)
 
 class LoHaAdapter(WeightAdapterBase):
     name = "loha"

@@ -75,19 +76,19 @@ class LoHaAdapter(WeightAdapterBase):
             t1 = v[5]
             t2 = v[6]
             m1 = torch.einsum('i j k l, j r, i p -> p r k l',
-                comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
-                comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
-                comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
+                cast_to_device(t1, weight.device, intermediate_dtype),
+                cast_to_device(w1b, weight.device, intermediate_dtype),
+                cast_to_device(w1a, weight.device, intermediate_dtype))
 
             m2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
-                comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
+                cast_to_device(t2, weight.device, intermediate_dtype),
+                cast_to_device(w2b, weight.device, intermediate_dtype),
+                cast_to_device(w2a, weight.device, intermediate_dtype))
         else:
-            m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
-                comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
-            m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
-                comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
+            m1 = torch.mm(cast_to_device(w1a, weight.device, intermediate_dtype),
+                cast_to_device(w1b, weight.device, intermediate_dtype))
+            m2 = torch.mm(cast_to_device(w2a, weight.device, intermediate_dtype),
+                cast_to_device(w2b, weight.device, intermediate_dtype))
 
         try:
             lora_diff = (m1 * m2).reshape(weight.shape)

@@ -96,5 +97,5 @@ class LoHaAdapter(WeightAdapterBase):
             else:
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
         except Exception as e:
-            logging.error("ERROR {} {} {}".format(self.name, key, e))
+            logger.error("ERROR {} {} {}".format(self.name, key, e))
         return weight
@@ -2,9 +2,10 @@ import logging
 from typing import Optional
 
 import torch
-import comfy.model_management
+from ..model_management import cast_to_device
 from .base import WeightAdapterBase, weight_decompose
 
+logger = logging.getLogger(__name__)
 
 class LoKrAdapter(WeightAdapterBase):
     name = "lokr"

@@ -97,23 +98,23 @@ class LoKrAdapter(WeightAdapterBase):
 
         if w1 is None:
             dim = w1_b.shape[0]
-            w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
-                comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
+            w1 = torch.mm(cast_to_device(w1_a, weight.device, intermediate_dtype),
+                cast_to_device(w1_b, weight.device, intermediate_dtype))
         else:
-            w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
+            w1 = cast_to_device(w1, weight.device, intermediate_dtype)
 
         if w2 is None:
             dim = w2_b.shape[0]
             if t2 is None:
-                w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
-                    comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
+                w2 = torch.mm(cast_to_device(w2_a, weight.device, intermediate_dtype),
+                    cast_to_device(w2_b, weight.device, intermediate_dtype))
             else:
                 w2 = torch.einsum('i j k l, j r, i p -> p r k l',
-                    comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
-                    comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
-                    comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
+                    cast_to_device(t2, weight.device, intermediate_dtype),
+                    cast_to_device(w2_b, weight.device, intermediate_dtype),
+                    cast_to_device(w2_a, weight.device, intermediate_dtype))
         else:
-            w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
+            w2 = cast_to_device(w2, weight.device, intermediate_dtype)
 
         if len(w2.shape) == 4:
             w1 = w1.unsqueeze(2).unsqueeze(2)

@@ -129,5 +130,5 @@ class LoKrAdapter(WeightAdapterBase):
             else:
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
         except Exception as e:
-            logging.error("ERROR {} {} {}".format(self.name, key, e))
+            logger.error("ERROR {} {} {}".format(self.name, key, e))
         return weight
@@ -2,9 +2,10 @@ import logging
 from typing import Optional
 
 import torch
-import comfy.model_management
+from ..model_management import cast_to_device
 from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
 
+logger = logging.getLogger(__name__)
 
 class LoRAAdapter(WeightAdapterBase):
     name = "lora"

@@ -90,10 +91,10 @@ class LoRAAdapter(WeightAdapterBase):
         original_weight=None,
     ):
         v = self.weights
-        mat1 = comfy.model_management.cast_to_device(
+        mat1 = cast_to_device(
             v[0], weight.device, intermediate_dtype
         )
-        mat2 = comfy.model_management.cast_to_device(
+        mat2 = cast_to_device(
             v[1], weight.device, intermediate_dtype
         )
         dora_scale = v[4]

@@ -109,7 +110,7 @@ class LoRAAdapter(WeightAdapterBase):
 
         if v[3] is not None:
             # locon mid weights, hopefully the math is fine because I didn't properly test it
-            mat3 = comfy.model_management.cast_to_device(
+            mat3 = cast_to_device(
                 v[3], weight.device, intermediate_dtype
             )
             final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]

@@ -138,5 +139,5 @@ class LoRAAdapter(WeightAdapterBase):
             else:
                 weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
         except Exception as e:
-            logging.error("ERROR {} {} {}".format(self.name, key, e))
+            logger.error("ERROR {} {} {}".format(self.name, key, e))
         return weight
@@ -2,9 +2,10 @@ import logging
 from typing import Optional
 
 import torch
-import comfy.model_management
+from ..model_management import cast_to_device
 from .base import WeightAdapterBase, weight_decompose
 
+logger = logging.getLogger(__name__)
 
 class OFTAdapter(WeightAdapterBase):
     name = "oft"

@@ -62,9 +63,9 @@ class OFTAdapter(WeightAdapterBase):
         alpha = v[2]
         dora_scale = v[3]
 
-        blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
+        blocks = cast_to_device(blocks, weight.device, intermediate_dtype)
         if rescale is not None:
-            rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
+            rescale = cast_to_device(rescale, weight.device, intermediate_dtype)
 
         block_num, block_size, *_ = blocks.shape
 

@@ -92,5 +93,5 @@ class OFTAdapter(WeightAdapterBase):
             else:
                 weight += function((strength * lora_diff).type(weight.dtype))
         except Exception as e:
-            logging.error("ERROR {} {} {}".format(self.name, key, e))
+            logger.error("ERROR {} {} {}".format(self.name, key, e))
         return weight
new file (empty): comfy_api/__init__.py
new file (empty): comfy_config/__init__.py
@@ -99,6 +99,7 @@ dependencies = [
     "colour",
     "av>=14.2.0",
     "pydantic~=2.0",
+    "pydantic-settings~=2.0",
    "typer",
    "ijson",
    "scikit-learn>=1.4.1",

@@ -249,4 +250,4 @@ exclude = ["*.ipynb"]
 allow-direct-references = true
 
 [tool.hatch.build.targets.wheel]
-packages = ["comfy/", "comfy_extras/"]
+packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/"]