fix imports and other basic problems

This commit is contained in:
doctorpangloss 2025-06-17 11:19:48 -07:00
parent 666f5b96f7
commit d79d7a7e08
35 changed files with 173 additions and 147 deletions

View File

@ -4,6 +4,7 @@ import os
from aiohttp import web
import logging
logger = logging.getLogger(__name__)
class AppSettings():
def __init__(self, user_manager):
@ -16,14 +17,14 @@ class AppSettings():
"comfy.settings.json"
)
except KeyError as e:
logging.error("User settings not found.")
loggererror("User settings not found.")
raise web.HTTPUnauthorized() from e
if os.path.isfile(file):
try:
with open(file) as f:
return json.load(f)
except:
logging.error(f"The user settings file is corrupted: {file}")
loggererror(f"The user settings file is corrupted: {file}")
return {}
else:
return {}

View File

@ -22,7 +22,7 @@ default_user = "default"
class FileInfo(TypedDict):
path: str
size: int
modified: int
modified: float
def get_file_info(path: str, relative_to: str) -> FileInfo:

View File

@ -97,7 +97,7 @@ class CONDList(CONDRegular):
def process_cond(self, batch_size, device, **kwargs):
out = []
for c in self.cond:
out.append(comfy.utils.repeat_to_batch_size(c, batch_size).to(device))
out.append(utils.repeat_to_batch_size(c, batch_size).to(device))
return self._copy_with(out)

View File

View File

@ -18,8 +18,8 @@ import torch
import torch.nn.functional as F
from torch import nn
import comfy.model_management
from comfy.ldm.modules.attention import optimized_attention
from ... import model_management
from ..modules.attention import optimized_attention
class Attention(nn.Module):
def __init__(
@ -704,7 +704,7 @@ class LinearTransformerBlock(nn.Module):
# step 1: AdaLN single
if self.use_adaln_single:
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
comfy.model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
model_management.cast_to(self.scale_shift_table[None], dtype=temb.dtype, device=temb.device) + temb.reshape(N, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)

View File

@ -4,7 +4,7 @@ import math
import torch
from torch import nn
import comfy.model_management
from ... import model_management
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model."""
@ -423,9 +423,9 @@ class RelPositionMultiHeadedAttention(MultiHeadedAttention):
p = p.transpose(1, 2) # (batch, head, time1, d_k)
# (batch, head, time1, d_k)
q_with_bias_u = (q + comfy.model_management.cast_to(self.pos_bias_u, dtype=q.dtype, device=q.device)).transpose(1, 2)
q_with_bias_u = (q + model_management.cast_to(self.pos_bias_u, dtype=q.dtype, device=q.device)).transpose(1, 2)
# (batch, head, time1, d_k)
q_with_bias_v = (q + comfy.model_management.cast_to(self.pos_bias_v, dtype=q.dtype, device=q.device)).transpose(1, 2)
q_with_bias_v = (q + model_management.cast_to(self.pos_bias_v, dtype=q.dtype, device=q.device)).transpose(1, 2)
# compute attention score
# first compute matrix a and matrix c

View File

@ -18,9 +18,9 @@ from typing import Optional, List, Union
import torch
from torch import nn
import comfy.model_management
from ... import model_management
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from ..lightricks.model import TimestepEmbedding, Timesteps
from .attention import LinearTransformerBlock, t2i_modulate
from .lyric_encoder import ConformerEncoder as LyricEncoder
@ -104,7 +104,7 @@ class T2IFinalLayer(nn.Module):
return output
def forward(self, x, t, output_length):
shift, scale = (comfy.model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
shift, scale = (model_management.cast_to(self.scale_shift_table[None], device=t.device, dtype=t.dtype) + t[:, None]).chunk(2, dim=1)
x = t2i_modulate(self.norm_final(x), shift, scale)
x = self.linear(x)
# unpatchify

View File

View File

@ -4,9 +4,8 @@ import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Union
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
from .... import model_management
from ....ops import disable_weight_init as ops
class RMSNorm(ops.RMSNorm):
@ -19,7 +18,7 @@ class RMSNorm(ops.RMSNorm):
x = super().forward(x)
if self.elementwise_affine:
if self.bias is not None:
x = x + comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
x = x + model_management.cast_to(self.bias, dtype=x.dtype, device=x.device)
return x

View File

@ -8,7 +8,7 @@ try:
except:
logging.warning("torchaudio missing, ACE model will be broken")
import comfy.model_management
from .... import model_management
class LinearSpectrogram(nn.Module):
def __init__(
@ -47,7 +47,7 @@ class LinearSpectrogram(nn.Module):
self.n_fft,
hop_length=self.hop_length,
win_length=self.win_length,
window=comfy.model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
window=model_management.cast_to(self.window, dtype=torch.float32, device=y.device),
center=self.center,
pad_mode="reflect",
normalized=False,

View File

@ -12,9 +12,8 @@ from torch.nn.utils.parametrize import remove_parametrizations as remove_weight_
from .music_log_mel import LogMelSpectrogram
import comfy.model_management
import comfy.ops
ops = comfy.ops.disable_weight_init
from .... import model_management
from ....ops import disable_weight_init as ops
def drop_path(
@ -77,13 +76,13 @@ class LayerNorm(nn.Module):
def forward(self, x):
if self.data_format == "channels_last":
return F.layer_norm(
x, self.normalized_shape, comfy.model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), comfy.model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
x, self.normalized_shape, model_management.cast_to(self.weight, dtype=x.dtype, device=x.device), model_management.cast_to(self.bias, dtype=x.dtype, device=x.device), self.eps
)
elif self.data_format == "channels_first":
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = comfy.model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + comfy.model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
x = model_management.cast_to(self.weight[:, None], dtype=x.dtype, device=x.device) * x + model_management.cast_to(self.bias[:, None], dtype=x.dtype, device=x.device)
return x
@ -145,7 +144,7 @@ class ConvNeXtBlock(nn.Module):
x = self.pwconv2(x)
if self.gamma is not None:
x = comfy.model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
x = model_management.cast_to(self.gamma, dtype=x.dtype, device=x.device) * x
x = x.permute(0, 2, 1) # (N, L, C) -> (N, C, L)
x = self.drop_path(x)

View File

View File

@ -1,14 +1,8 @@
import torch
from torch import Tensor, nn
from comfy.ldm.flux.math import attention
from comfy.ldm.flux.layers import (
MLPEmbedder,
RMSNorm,
QKNorm,
SelfAttention,
ModulationOut,
)
from ..flux.math import attention
from ..flux.layers import MLPEmbedder, RMSNorm, QKNorm, SelfAttention, ModulationOut

View File

@ -5,12 +5,9 @@ from dataclasses import dataclass
import torch
from torch import Tensor, nn
from einops import rearrange, repeat
import comfy.ldm.common_dit
from ..common_dit import pad_to_patch_size
from comfy.ldm.flux.layers import (
EmbedND,
timestep_embedding,
)
from ..flux.layers import EmbedND, timestep_embedding
from .layers import (
DoubleStreamBlock,
@ -255,7 +252,7 @@ class Chroma(nn.Module):
def forward(self, x, timestep, context, guidance, control=None, transformer_options={}, **kwargs):
bs, c, h, w = x.shape
patch_size = 2
x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
x = pad_to_patch_size(x, (patch_size, patch_size))
img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size)

View File

View File

@ -5,15 +5,15 @@ import torch.nn as nn
import einops
from einops import repeat
from comfy.ldm.lightricks.model import TimestepEmbedding, Timesteps
from ..lightricks.model import TimestepEmbedding, Timesteps
import torch.nn.functional as F
from comfy.ldm.flux.math import apply_rope, rope
from comfy.ldm.flux.layers import LastLayer
from ..flux.math import apply_rope, rope
from ..flux.layers import LastLayer
from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management
import comfy.ldm.common_dit
from ..modules.attention import optimized_attention
from ...model_management import cast_to
from ..common_dit import pad_to_patch_size
# Copied from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
@ -261,7 +261,7 @@ class MoEGate(nn.Module):
### compute gating score
hidden_states = hidden_states.view(-1, h)
logits = F.linear(hidden_states, comfy.model_management.cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
logits = F.linear(hidden_states, cast_to(self.weight, dtype=hidden_states.dtype, device=hidden_states.device), None)
if self.scoring_func == 'softmax':
scores = logits.softmax(dim=-1)
else:
@ -706,7 +706,7 @@ class HiDreamImageTransformer2DModel(nn.Module):
bs, c, h, w = x.shape
if image_cond is not None:
x = torch.cat([x, image_cond], dim=-1)
hidden_states = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
hidden_states = pad_to_patch_size(x, (self.patch_size, self.patch_size))
timesteps = t
pooled_embeds = y
T5_encoder_hidden_states = context

View File

View File

@ -6,11 +6,11 @@ import torch
import torch.nn as nn
from einops import repeat
from comfy.ldm.modules.attention import optimized_attention
from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.ldm.common_dit
import comfy.model_management
from ..modules.attention import optimized_attention
from ..flux.layers import EmbedND
from ..flux.math import apply_rope
from ..common_dit import pad_to_patch_size
from ...model_management import cast_to
def sinusoidal_embedding_1d(dim, position):
@ -202,7 +202,7 @@ class WanAttentionBlock(nn.Module):
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
# assert e[0].dtype == torch.float32
# self-attention
@ -325,7 +325,7 @@ class Head(nn.Module):
e(Tensor): Shape [B, C]
"""
# assert e.dtype == torch.float32
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
e = (cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
return x
@ -347,7 +347,7 @@ class MLPProj(torch.nn.Module):
def forward(self, image_embeds):
if self.emb_pos is not None:
image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + comfy.model_management.cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
image_embeds = image_embeds[:, :self.emb_pos.shape[1]] + cast_to(self.emb_pos[:, :image_embeds.shape[1]], dtype=image_embeds.dtype, device=image_embeds.device)
clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens
@ -541,7 +541,7 @@ class WanModel(torch.nn.Module):
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
bs, c, t, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
x = pad_to_patch_size(x, self.patch_size)
patch_size = self.patch_size
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
@ -549,7 +549,7 @@ class WanModel(torch.nn.Module):
w_len = ((w + (patch_size[2] // 2)) // patch_size[2])
if time_dim_concat is not None:
time_dim_concat = comfy.ldm.common_dit.pad_to_patch_size(time_dim_concat, self.patch_size)
time_dim_concat = pad_to_patch_size(time_dim_concat, self.patch_size)
x = torch.cat([x, time_dim_concat], dim=2)
t_len = ((x.shape[2] + (patch_size[0] // 2)) // patch_size[0])

View File

@ -1,5 +1,5 @@
import torch
import comfy.utils
from .utils import state_dict_prefix_replace
def convert_lora_bfl_control(sd): #BFL loras for Flux
@ -13,7 +13,7 @@ def convert_lora_bfl_control(sd): #BFL loras for Flux
def convert_lora_wan_fun(sd): #Wan Fun loras
return comfy.utils.state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
return state_dict_prefix_replace(sd, {"lora_unet__": "lora_unet_"})
def convert_lora(sd):

View File

@ -1,13 +1,16 @@
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
import comfy.text_encoders.t5
from importlib.resources import files
import logging
import os
import re
import torch
import logging
from tokenizers import Tokenizer
from .ace_text_cleaners import multilingual_cleaners, japanese_to_romaji
from .spiece_tokenizer import SPieceTokenizer
from .t5 import T5
from .. import sd1_clip
from ..component_model.files import get_path_as_dict
SUPPORT_LANGUAGES = {
"en": 259, "de": 260, "fr": 262, "es": 284, "it": 285,
@ -18,11 +21,16 @@ SUPPORT_LANGUAGES = {
structure_pattern = re.compile(r"\[.*?\]")
DEFAULT_VOCAB_FILE = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
def get_vocab_file() -> str:
return str(files(f"{__package__}.ace_lyrics_tokenizer") / "vocab.json")
class VoiceBpeTokenizer:
def __init__(self, vocab_file=DEFAULT_VOCAB_FILE):
def __init__(self, vocab_file=None):
vocab_file = vocab_file or get_vocab_file()
self.tokenizer = None
if vocab_file is not None:
self.tokenizer = Tokenizer.from_file(vocab_file)
@ -92,29 +100,40 @@ class VoiceBpeTokenizer:
class UMT5BaseModel(sd1_clip.SDClipModel):
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options={}):
textmodel_json_config = os.path.join(os.path.dirname(os.path.realpath(__file__)), "umt5_config_base.json")
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
def __init__(self, device="cpu", layer="last", layer_idx=None, dtype=None, model_options=None, textmodel_json_config=None):
if model_options is None:
model_options = {}
textmodel_json_config = get_path_as_dict(textmodel_json_config, "umt5_config_base.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, zero_out_masked=False, model_options=model_options)
class UMT5BaseTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=768, embedding_key='umt5base', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=0, tokenizer_data=tokenizer_data)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}
class LyricsTokenizer(sd1_clip.SDTokenizer):
def __init__(self, embedding_directory=None, tokenizer_data={}):
tokenizer = os.path.join(os.path.join(os.path.dirname(os.path.realpath(__file__)), "ace_lyrics_tokenizer"), "vocab.json")
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer = get_vocab_file()
super().__init__(tokenizer, pad_with_end=False, embedding_size=1024, embedding_key='lyrics', tokenizer_class=VoiceBpeTokenizer, has_start_token=True, pad_to_max_length=False, max_length=99999999, min_length=1, pad_token=2, has_end_token=False, tokenizer_data=tokenizer_data)
class AceT5Tokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
self.voicebpe = LyricsTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.umt5base = UMT5BaseTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
out = {}
out["lyrics"] = self.voicebpe.tokenize_with_weights(kwargs.get("lyrics", ""), return_word_ids, **kwargs)
out["umt5base"] = self.umt5base.tokenize_with_weights(text, return_word_ids, **kwargs)
@ -126,9 +145,12 @@ class AceT5Tokenizer:
def state_dict(self):
return self.umt5base.state_dict()
class AceT5Model(torch.nn.Module):
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
def __init__(self, device="cpu", dtype=None, model_options=None, **kwargs):
super().__init__()
if model_options is None:
model_options = {}
self.umt5base = UMT5BaseModel(device=device, dtype=dtype, model_options=model_options)
self.dtypes = set()
if dtype is not None:

View File

@ -1,6 +1,6 @@
from transformers import T5TokenizerFast
from comfy import sd1_clip
from .. import sd1_clip
from . import sd3_clip
from ..component_model import files

View File

@ -1,20 +1,24 @@
import logging
import torch
from . import hunyuan_video
from . import sd3_clip
from comfy import sd1_clip
from comfy import sdxl_clip
import comfy.model_management
import torch
import logging
from .. import sd1_clip
from .. import sdxl_clip
from ..model_management import intermediate_device, pick_weight_dtype
logger = logging.getLogger(__name__)
class HiDreamTokenizer:
def __init__(self, embedding_directory=None, tokenizer_data={}):
def __init__(self, embedding_directory=None, tokenizer_data=None):
if tokenizer_data is None:
tokenizer_data = {}
self.clip_l = sd1_clip.SDTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.clip_g = sdxl_clip.SDXLClipGTokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
self.t5xxl = sd3_clip.T5XXLTokenizer(embedding_directory=embedding_directory, min_length=128, max_length=128, tokenizer_data=tokenizer_data)
self.llama = hunyuan_video.LLAMA3Tokenizer(embedding_directory=embedding_directory, min_length=128, pad_token=128009, tokenizer_data=tokenizer_data)
def tokenize_with_weights(self, text:str, return_word_ids=False, **kwargs):
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
out = {}
out["g"] = self.clip_g.tokenize_with_weights(text, return_word_ids, **kwargs)
out["l"] = self.clip_l.tokenize_with_weights(text, return_word_ids, **kwargs)
@ -31,8 +35,10 @@ class HiDreamTokenizer:
class HiDreamTEModel(torch.nn.Module):
def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options={}):
def __init__(self, clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, device="cpu", dtype=None, model_options=None):
super().__init__()
if model_options is None:
model_options = {}
self.dtypes = set()
if clip_l:
self.clip_l = sd1_clip.SDClipModel(device=device, dtype=dtype, return_projected_pooled=True, model_options=model_options)
@ -47,14 +53,14 @@ class HiDreamTEModel(torch.nn.Module):
self.clip_g = None
if t5:
dtype_t5 = comfy.model_management.pick_weight_dtype(dtype_t5, dtype, device)
dtype_t5 = pick_weight_dtype(dtype_t5, dtype, device)
self.t5xxl = sd3_clip.T5XXLModel(device=device, dtype=dtype_t5, model_options=model_options, attention_mask=True)
self.dtypes.add(dtype_t5)
else:
self.t5xxl = None
if llama:
dtype_llama = comfy.model_management.pick_weight_dtype(dtype_llama, dtype, device)
dtype_llama = pick_weight_dtype(dtype_llama, dtype, device)
if "vocab_size" not in model_options:
model_options["vocab_size"] = 128256
self.llama = hunyuan_video.LLAMAModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None, special_tokens={"start": 128000, "pad": 128009})
@ -62,7 +68,7 @@ class HiDreamTEModel(torch.nn.Module):
else:
self.llama = None
logging.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
logger.debug("Created HiDream text encoder with: clip_l {}, clip_g {}, t5xxl {}:{}, llama {}:{}".format(clip_l, clip_g, t5, dtype_t5, llama, dtype_llama))
def set_clip_options(self, options):
if self.clip_l is not None:
@ -97,12 +103,12 @@ class HiDreamTEModel(torch.nn.Module):
if self.clip_l is not None:
lg_out, l_pooled = self.clip_l.encode_token_weights(token_weight_pairs_l)
else:
l_pooled = torch.zeros((1, 768), device=comfy.model_management.intermediate_device())
l_pooled = torch.zeros((1, 768), device=intermediate_device())
if self.clip_g is not None:
g_out, g_pooled = self.clip_g.encode_token_weights(token_weight_pairs_g)
else:
g_pooled = torch.zeros((1, 1280), device=comfy.model_management.intermediate_device())
g_pooled = torch.zeros((1, 1280), device=intermediate_device())
pooled = torch.cat((l_pooled, g_pooled), dim=-1)
@ -120,13 +126,13 @@ class HiDreamTEModel(torch.nn.Module):
ll_out = None
if t5_out is None:
t5_out = torch.zeros((1, 128, 4096), device=comfy.model_management.intermediate_device())
t5_out = torch.zeros((1, 128, 4096), device=intermediate_device())
if ll_out is None:
ll_out = torch.zeros((1, 32, 1, 4096), device=comfy.model_management.intermediate_device())
ll_out = torch.zeros((1, 32, 1, 4096), device=intermediate_device())
if pooled is None:
pooled = torch.zeros((1, 768 + 1280), device=comfy.model_management.intermediate_device())
pooled = torch.zeros((1, 768 + 1280), device=intermediate_device())
extra["conditioning_llama3"] = ll_out
return t5_out, pooled, extra
@ -144,7 +150,9 @@ class HiDreamTEModel(torch.nn.Module):
def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, dtype_llama=None, t5xxl_scaled_fp8=None, llama_scaled_fp8=None):
class HiDreamTEModel_(HiDreamTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
def __init__(self, device="cpu", dtype=None, model_options=None):
if model_options is None:
model_options = {}
if t5xxl_scaled_fp8 is not None and "t5xxl_scaled_fp8" not in model_options:
model_options = model_options.copy()
model_options["t5xxl_scaled_fp8"] = t5xxl_scaled_fp8
@ -152,4 +160,5 @@ def hidream_clip(clip_l=True, clip_g=True, t5=True, llama=True, dtype_t5=None, d
model_options = model_options.copy()
model_options["llama_scaled_fp8"] = llama_scaled_fp8
super().__init__(clip_l=clip_l, clip_g=clip_g, t5=t5, llama=llama, dtype_t5=dtype_t5, dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
return HiDreamTEModel_

View File

@ -16,7 +16,7 @@ class HyditBertModel(sd1_clip.SDClipModel):
model_options = dict()
textmodel_json_config = get_path_as_dict(textmodel_json_config, "hydit_clip.json", package=__package__)
model_options = {**model_options, "model_name": "hydit_clip"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 101, "end": 102, "pad": 0}, model_class=BertModel, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class HyditBertTokenizer(sd1_clip.SDTokenizer):
def __init__(self, **kwargs):
@ -31,7 +31,7 @@ class MT5XLModel(sd1_clip.SDClipModel):
model_options = dict()
textmodel_json_config = get_path_as_dict(textmodel_json_config, "mt5_config_xl.json", package=__package__)
model_options = {**model_options, "model_name": "mt5xl"}
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, return_attention_masks=True)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, return_attention_masks=True, model_options=model_options)
class MT5XLTokenizer(sd1_clip.SDTokenizer):
def __init__(self, tokenizer_data=None, **kwargs):

View File

@ -1,8 +1,6 @@
import os
import comfy.text_encoders.t5
from comfy import sd1_clip
from .spiece_tokenizer import SPieceTokenizer
from .t5 import T5
from .. import sd1_clip
from ..component_model.files import get_path_as_dict
@ -11,7 +9,7 @@ class UMT5XXlModel(sd1_clip.SDClipModel):
if model_options is None:
model_options = {}
textmodel_json_config = get_path_as_dict(textmodel_json_config, "umt5_config_xxl.json", package=__package__)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=comfy.text_encoders.t5.T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"end": 1, "pad": 0}, model_class=T5, enable_attention_masks=True, zero_out_masked=True, model_options=model_options)
class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
@ -19,7 +17,7 @@ class UMT5XXlTokenizer(sd1_clip.SDTokenizer):
if tokenizer_data is None:
tokenizer_data = {}
tokenizer = tokenizer_data.get("spiece_model", None)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data)
super().__init__(tokenizer, pad_with_end=False, embedding_size=4096, embedding_key='umt5xxl', tokenizer_class=SPieceTokenizer, has_start_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_token=0, tokenizer_data=tokenizer_data, embedding_directory=embedding_directory)
def state_dict(self):
return {"spiece_model": self.tokenizer.serialize_model()}

View File

@ -3,7 +3,7 @@ from typing import Optional
import torch
import torch.nn as nn
import comfy.model_management
from ..model_management import cast_to_device
class WeightAdapterBase:
@ -40,7 +40,7 @@ class WeightAdapterTrainBase(nn.Module):
def weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function):
dora_scale = comfy.model_management.cast_to_device(dora_scale, weight.device, intermediate_dtype)
dora_scale = cast_to_device(dora_scale, weight.device, intermediate_dtype)
lora_diff *= alpha
weight_calc = weight + function(lora_diff).type(weight.dtype)

View File

@ -2,9 +2,10 @@ import logging
from typing import Optional
import torch
import comfy.model_management
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
logger = logging.getLogger(__name__)
class BOFTAdapter(WeightAdapterBase):
name = "boft"
@ -62,9 +63,9 @@ class BOFTAdapter(WeightAdapterBase):
alpha = v[2]
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
blocks = cast_to_device(blocks, weight.device, intermediate_dtype)
if rescale is not None:
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
rescale = cast_to_device(rescale, weight.device, intermediate_dtype)
boft_m, block_num, boft_b, *_ = blocks.shape
@ -105,11 +106,11 @@ class BOFTAdapter(WeightAdapterBase):
inp = inp * rescale
lora_diff = inp - org
lora_diff = comfy.model_management.cast_to_device(lora_diff, weight.device, intermediate_dtype)
lora_diff = cast_to_device(lora_diff, weight.device, intermediate_dtype)
if dora_scale is not None:
weight = weight_decompose(dora_scale, weight, lora_diff, alpha, strength, intermediate_dtype, function)
else:
weight += function((strength * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
logger.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@ -2,9 +2,10 @@ import logging
from typing import Optional
import torch
import comfy.model_management
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
logger = logging.getLogger(__name__)
class GLoRAAdapter(WeightAdapterBase):
name = "glora"
@ -64,10 +65,10 @@ class GLoRAAdapter(WeightAdapterBase):
old_glora = False
rank = v[1].shape[0]
a1 = comfy.model_management.cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = comfy.model_management.cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = comfy.model_management.cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = comfy.model_management.cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
a1 = cast_to_device(v[0].flatten(start_dim=1), weight.device, intermediate_dtype)
a2 = cast_to_device(v[1].flatten(start_dim=1), weight.device, intermediate_dtype)
b1 = cast_to_device(v[2].flatten(start_dim=1), weight.device, intermediate_dtype)
b2 = cast_to_device(v[3].flatten(start_dim=1), weight.device, intermediate_dtype)
if v[4] is not None:
alpha = v[4] / rank
@ -89,5 +90,5 @@ class GLoRAAdapter(WeightAdapterBase):
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
logger.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@ -2,9 +2,10 @@ import logging
from typing import Optional
import torch
import comfy.model_management
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
logger = logging.getLogger(__name__)
class LoHaAdapter(WeightAdapterBase):
name = "loha"
@ -75,19 +76,19 @@ class LoHaAdapter(WeightAdapterBase):
t1 = v[5]
t2 = v[6]
m1 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t1, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype))
cast_to_device(t1, weight.device, intermediate_dtype),
cast_to_device(w1b, weight.device, intermediate_dtype),
cast_to_device(w1a, weight.device, intermediate_dtype))
m2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype))
cast_to_device(t2, weight.device, intermediate_dtype),
cast_to_device(w2b, weight.device, intermediate_dtype),
cast_to_device(w2a, weight.device, intermediate_dtype))
else:
m1 = torch.mm(comfy.model_management.cast_to_device(w1a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(comfy.model_management.cast_to_device(w2a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2b, weight.device, intermediate_dtype))
m1 = torch.mm(cast_to_device(w1a, weight.device, intermediate_dtype),
cast_to_device(w1b, weight.device, intermediate_dtype))
m2 = torch.mm(cast_to_device(w2a, weight.device, intermediate_dtype),
cast_to_device(w2b, weight.device, intermediate_dtype))
try:
lora_diff = (m1 * m2).reshape(weight.shape)
@ -96,5 +97,5 @@ class LoHaAdapter(WeightAdapterBase):
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
logger.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@ -2,9 +2,10 @@ import logging
from typing import Optional
import torch
import comfy.model_management
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
logger = logging.getLogger(__name__)
class LoKrAdapter(WeightAdapterBase):
name = "lokr"
@ -97,23 +98,23 @@ class LoKrAdapter(WeightAdapterBase):
if w1 is None:
dim = w1_b.shape[0]
w1 = torch.mm(comfy.model_management.cast_to_device(w1_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w1_b, weight.device, intermediate_dtype))
w1 = torch.mm(cast_to_device(w1_a, weight.device, intermediate_dtype),
cast_to_device(w1_b, weight.device, intermediate_dtype))
else:
w1 = comfy.model_management.cast_to_device(w1, weight.device, intermediate_dtype)
w1 = cast_to_device(w1, weight.device, intermediate_dtype)
if w2 is None:
dim = w2_b.shape[0]
if t2 is None:
w2 = torch.mm(comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype))
w2 = torch.mm(cast_to_device(w2_a, weight.device, intermediate_dtype),
cast_to_device(w2_b, weight.device, intermediate_dtype))
else:
w2 = torch.einsum('i j k l, j r, i p -> p r k l',
comfy.model_management.cast_to_device(t2, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_b, weight.device, intermediate_dtype),
comfy.model_management.cast_to_device(w2_a, weight.device, intermediate_dtype))
cast_to_device(t2, weight.device, intermediate_dtype),
cast_to_device(w2_b, weight.device, intermediate_dtype),
cast_to_device(w2_a, weight.device, intermediate_dtype))
else:
w2 = comfy.model_management.cast_to_device(w2, weight.device, intermediate_dtype)
w2 = cast_to_device(w2, weight.device, intermediate_dtype)
if len(w2.shape) == 4:
w1 = w1.unsqueeze(2).unsqueeze(2)
@ -129,5 +130,5 @@ class LoKrAdapter(WeightAdapterBase):
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
logger.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@ -2,9 +2,10 @@ import logging
from typing import Optional
import torch
import comfy.model_management
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose, pad_tensor_to_shape
logger = logging.getLogger(__name__)
class LoRAAdapter(WeightAdapterBase):
name = "lora"
@ -90,10 +91,10 @@ class LoRAAdapter(WeightAdapterBase):
original_weight=None,
):
v = self.weights
mat1 = comfy.model_management.cast_to_device(
mat1 = cast_to_device(
v[0], weight.device, intermediate_dtype
)
mat2 = comfy.model_management.cast_to_device(
mat2 = cast_to_device(
v[1], weight.device, intermediate_dtype
)
dora_scale = v[4]
@ -109,7 +110,7 @@ class LoRAAdapter(WeightAdapterBase):
if v[3] is not None:
# locon mid weights, hopefully the math is fine because I didn't properly test it
mat3 = comfy.model_management.cast_to_device(
mat3 = cast_to_device(
v[3], weight.device, intermediate_dtype
)
final_shape = [mat2.shape[1], mat2.shape[0], mat3.shape[2], mat3.shape[3]]
@ -138,5 +139,5 @@ class LoRAAdapter(WeightAdapterBase):
else:
weight += function(((strength * alpha) * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
logger.error("ERROR {} {} {}".format(self.name, key, e))
return weight

View File

@ -2,9 +2,10 @@ import logging
from typing import Optional
import torch
import comfy.model_management
from ..model_management import cast_to_device
from .base import WeightAdapterBase, weight_decompose
logger = logging.getLogger(__name__)
class OFTAdapter(WeightAdapterBase):
name = "oft"
@ -62,9 +63,9 @@ class OFTAdapter(WeightAdapterBase):
alpha = v[2]
dora_scale = v[3]
blocks = comfy.model_management.cast_to_device(blocks, weight.device, intermediate_dtype)
blocks = cast_to_device(blocks, weight.device, intermediate_dtype)
if rescale is not None:
rescale = comfy.model_management.cast_to_device(rescale, weight.device, intermediate_dtype)
rescale = cast_to_device(rescale, weight.device, intermediate_dtype)
block_num, block_size, *_ = blocks.shape
@ -92,5 +93,5 @@ class OFTAdapter(WeightAdapterBase):
else:
weight += function((strength * lora_diff).type(weight.dtype))
except Exception as e:
logging.error("ERROR {} {} {}".format(self.name, key, e))
logger.error("ERROR {} {} {}".format(self.name, key, e))
return weight

0
comfy_api/__init__.py Normal file
View File

0
comfy_config/__init__.py Normal file
View File

View File

@ -99,6 +99,7 @@ dependencies = [
"colour",
"av>=14.2.0",
"pydantic~=2.0",
"pydantic-settings~=2.0",
"typer",
"ijson",
"scikit-learn>=1.4.1",
@ -249,4 +250,4 @@ exclude = ["*.ipynb"]
allow-direct-references = true
[tool.hatch.build.targets.wheel]
packages = ["comfy/", "comfy_extras/"]
packages = ["comfy/", "comfy_extras/", "comfy_api/", "comfy_api_nodes/", "comfy_config/"]