mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-19 15:02:44 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
d8f32ffda1
@ -297,6 +297,30 @@ class ControlNet(ControlBase):
|
|||||||
self.model_sampling_current = None
|
self.model_sampling_current = None
|
||||||
super().cleanup()
|
super().cleanup()
|
||||||
|
|
||||||
|
|
||||||
|
class QwenFunControlNet(ControlNet):
|
||||||
|
def get_control(self, x_noisy, t, cond, batched_number, transformer_options):
|
||||||
|
# Fun checkpoints are more sensitive to high strengths in the generic
|
||||||
|
# ControlNet merge path. Use a soft response curve so strength=1.0 stays
|
||||||
|
# unchanged while >1 grows more gently.
|
||||||
|
original_strength = self.strength
|
||||||
|
self.strength = math.sqrt(max(self.strength, 0.0))
|
||||||
|
try:
|
||||||
|
return super().get_control(x_noisy, t, cond, batched_number, transformer_options)
|
||||||
|
finally:
|
||||||
|
self.strength = original_strength
|
||||||
|
|
||||||
|
def pre_run(self, model, percent_to_timestep_function):
|
||||||
|
super().pre_run(model, percent_to_timestep_function)
|
||||||
|
self.set_extra_arg("base_model", model.diffusion_model)
|
||||||
|
|
||||||
|
def copy(self):
|
||||||
|
c = QwenFunControlNet(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
|
||||||
|
c.control_model = self.control_model
|
||||||
|
c.control_model_wrapped = self.control_model_wrapped
|
||||||
|
self.copy_to(c)
|
||||||
|
return c
|
||||||
|
|
||||||
class ControlLoraOps:
|
class ControlLoraOps:
|
||||||
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
class Linear(torch.nn.Module, comfy.ops.CastWeightBiasOp):
|
||||||
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
def __init__(self, in_features: int, out_features: int, bias: bool = True,
|
||||||
@ -606,6 +630,53 @@ def load_controlnet_qwen_instantx(sd, model_options={}):
|
|||||||
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
control = ControlNet(control_model, compression_ratio=1, latent_format=latent_format, concat_mask=concat_mask, load_device=load_device, manual_cast_dtype=manual_cast_dtype, extra_conds=extra_conds)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
def load_controlnet_qwen_fun(sd, model_options={}):
|
||||||
|
load_device = comfy.model_management.get_torch_device()
|
||||||
|
weight_dtype = comfy.utils.weight_dtype(sd)
|
||||||
|
unet_dtype = model_options.get("dtype", weight_dtype)
|
||||||
|
manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
|
||||||
|
|
||||||
|
operations = model_options.get("custom_operations", None)
|
||||||
|
if operations is None:
|
||||||
|
operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
|
||||||
|
|
||||||
|
in_features = sd["control_img_in.weight"].shape[1]
|
||||||
|
inner_dim = sd["control_img_in.weight"].shape[0]
|
||||||
|
|
||||||
|
block_weight = sd["control_blocks.0.attn.to_q.weight"]
|
||||||
|
attention_head_dim = sd["control_blocks.0.attn.norm_q.weight"].shape[0]
|
||||||
|
num_attention_heads = max(1, block_weight.shape[0] // max(1, attention_head_dim))
|
||||||
|
|
||||||
|
model = comfy.ldm.qwen_image.controlnet.QwenImageFunControlNetModel(
|
||||||
|
control_in_features=in_features,
|
||||||
|
inner_dim=inner_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
num_control_blocks=5,
|
||||||
|
main_model_double=60,
|
||||||
|
injection_layers=(0, 12, 24, 36, 48),
|
||||||
|
operations=operations,
|
||||||
|
device=comfy.model_management.unet_offload_device(),
|
||||||
|
dtype=unet_dtype,
|
||||||
|
)
|
||||||
|
model = controlnet_load_state_dict(model, sd)
|
||||||
|
|
||||||
|
latent_format = comfy.latent_formats.Wan21()
|
||||||
|
control = QwenFunControlNet(
|
||||||
|
model,
|
||||||
|
compression_ratio=1,
|
||||||
|
latent_format=latent_format,
|
||||||
|
# Fun checkpoints already expect their own 33-channel context handling.
|
||||||
|
# Enabling generic concat_mask injects an extra mask channel at apply-time
|
||||||
|
# and breaks the intended fallback packing path.
|
||||||
|
concat_mask=False,
|
||||||
|
load_device=load_device,
|
||||||
|
manual_cast_dtype=manual_cast_dtype,
|
||||||
|
extra_conds=[],
|
||||||
|
)
|
||||||
|
return control
|
||||||
|
|
||||||
def convert_mistoline(sd):
|
def convert_mistoline(sd):
|
||||||
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
return comfy.utils.state_dict_prefix_replace(sd, {"single_controlnet_blocks.": "controlnet_single_blocks."})
|
||||||
|
|
||||||
@ -683,6 +754,8 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
|
|||||||
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
return load_controlnet_qwen_instantx(controlnet_data, model_options=model_options)
|
||||||
elif "controlnet_x_embedder.weight" in controlnet_data:
|
elif "controlnet_x_embedder.weight" in controlnet_data:
|
||||||
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
|
||||||
|
elif "control_blocks.0.after_proj.weight" in controlnet_data and "control_img_in.weight" in controlnet_data:
|
||||||
|
return load_controlnet_qwen_fun(controlnet_data, model_options=model_options)
|
||||||
|
|
||||||
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux
|
||||||
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
return load_controlnet_flux_xlabs_mistoline(convert_mistoline(controlnet_data), mistoline=True, model_options=model_options)
|
||||||
|
|||||||
@ -6,6 +6,8 @@ from torch import Tensor, nn
|
|||||||
|
|
||||||
from .math import attention, rope
|
from .math import attention, rope
|
||||||
|
|
||||||
|
# Fix import for some custom nodes, TODO: delete eventually.
|
||||||
|
RMSNorm = None
|
||||||
|
|
||||||
class EmbedND(nn.Module):
|
class EmbedND(nn.Module):
|
||||||
def __init__(self, dim: int, theta: int, axes_dim: list):
|
def __init__(self, dim: int, theta: int, axes_dim: list):
|
||||||
|
|||||||
@ -2,6 +2,196 @@ import torch
|
|||||||
import math
|
import math
|
||||||
|
|
||||||
from .model import QwenImageTransformer2DModel
|
from .model import QwenImageTransformer2DModel
|
||||||
|
from .model import QwenImageTransformerBlock
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageFunControlBlock(QwenImageTransformerBlock):
|
||||||
|
def __init__(self, dim, num_attention_heads, attention_head_dim, has_before_proj=False, dtype=None, device=None, operations=None):
|
||||||
|
super().__init__(
|
||||||
|
dim=dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
self.has_before_proj = has_before_proj
|
||||||
|
if has_before_proj:
|
||||||
|
self.before_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
self.after_proj = operations.Linear(dim, dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class QwenImageFunControlNetModel(torch.nn.Module):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
control_in_features=132,
|
||||||
|
inner_dim=3072,
|
||||||
|
num_attention_heads=24,
|
||||||
|
attention_head_dim=128,
|
||||||
|
num_control_blocks=5,
|
||||||
|
main_model_double=60,
|
||||||
|
injection_layers=(0, 12, 24, 36, 48),
|
||||||
|
dtype=None,
|
||||||
|
device=None,
|
||||||
|
operations=None,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
self.dtype = dtype
|
||||||
|
self.main_model_double = main_model_double
|
||||||
|
self.injection_layers = tuple(injection_layers)
|
||||||
|
# Keep base hint scaling at 1.0 so user-facing strength behaves similarly
|
||||||
|
# to the reference Gen2/VideoX implementation around strength=1.
|
||||||
|
self.hint_scale = 1.0
|
||||||
|
self.control_img_in = operations.Linear(control_in_features, inner_dim, device=device, dtype=dtype)
|
||||||
|
|
||||||
|
self.control_blocks = torch.nn.ModuleList([])
|
||||||
|
for i in range(num_control_blocks):
|
||||||
|
self.control_blocks.append(
|
||||||
|
QwenImageFunControlBlock(
|
||||||
|
dim=inner_dim,
|
||||||
|
num_attention_heads=num_attention_heads,
|
||||||
|
attention_head_dim=attention_head_dim,
|
||||||
|
has_before_proj=(i == 0),
|
||||||
|
dtype=dtype,
|
||||||
|
device=device,
|
||||||
|
operations=operations,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def _process_hint_tokens(self, hint):
|
||||||
|
if hint is None:
|
||||||
|
return None
|
||||||
|
if hint.ndim == 4:
|
||||||
|
hint = hint.unsqueeze(2)
|
||||||
|
|
||||||
|
# Fun checkpoints are trained with 33 latent channels before 2x2 packing:
|
||||||
|
# [control_latent(16), mask(1), inpaint_latent(16)] -> 132 features.
|
||||||
|
# Default behavior (no inpaint input in stock Apply ControlNet) should use
|
||||||
|
# zeros for mask/inpaint branches, matching VideoX fallback semantics.
|
||||||
|
expected_c = self.control_img_in.weight.shape[1] // 4
|
||||||
|
if hint.shape[1] == 16 and expected_c == 33:
|
||||||
|
zeros_mask = torch.zeros_like(hint[:, :1])
|
||||||
|
zeros_inpaint = torch.zeros_like(hint)
|
||||||
|
hint = torch.cat([hint, zeros_mask, zeros_inpaint], dim=1)
|
||||||
|
|
||||||
|
bs, c, t, h, w = hint.shape
|
||||||
|
hidden_states = torch.nn.functional.pad(hint, (0, w % 2, 0, h % 2))
|
||||||
|
orig_shape = hidden_states.shape
|
||||||
|
hidden_states = hidden_states.view(
|
||||||
|
orig_shape[0],
|
||||||
|
orig_shape[1],
|
||||||
|
orig_shape[-3],
|
||||||
|
orig_shape[-2] // 2,
|
||||||
|
2,
|
||||||
|
orig_shape[-1] // 2,
|
||||||
|
2,
|
||||||
|
)
|
||||||
|
hidden_states = hidden_states.permute(0, 2, 3, 5, 1, 4, 6)
|
||||||
|
hidden_states = hidden_states.reshape(
|
||||||
|
bs,
|
||||||
|
t * ((h + 1) // 2) * ((w + 1) // 2),
|
||||||
|
c * 4,
|
||||||
|
)
|
||||||
|
|
||||||
|
expected_in = self.control_img_in.weight.shape[1]
|
||||||
|
cur_in = hidden_states.shape[-1]
|
||||||
|
if cur_in < expected_in:
|
||||||
|
pad = torch.zeros(
|
||||||
|
(hidden_states.shape[0], hidden_states.shape[1], expected_in - cur_in),
|
||||||
|
device=hidden_states.device,
|
||||||
|
dtype=hidden_states.dtype,
|
||||||
|
)
|
||||||
|
hidden_states = torch.cat([hidden_states, pad], dim=-1)
|
||||||
|
elif cur_in > expected_in:
|
||||||
|
hidden_states = hidden_states[:, :, :expected_in]
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
x,
|
||||||
|
timesteps,
|
||||||
|
context,
|
||||||
|
attention_mask=None,
|
||||||
|
guidance: torch.Tensor = None,
|
||||||
|
hint=None,
|
||||||
|
transformer_options={},
|
||||||
|
base_model=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
if base_model is None:
|
||||||
|
raise RuntimeError("Qwen Fun ControlNet requires a QwenImage base model at runtime.")
|
||||||
|
|
||||||
|
encoder_hidden_states_mask = attention_mask
|
||||||
|
# Keep attention mask disabled inside Fun control blocks to mirror
|
||||||
|
# VideoX behavior (they rely on seq lengths for RoPE, not masked attention).
|
||||||
|
encoder_hidden_states_mask = None
|
||||||
|
|
||||||
|
hidden_states, img_ids, _ = base_model.process_img(x)
|
||||||
|
hint_tokens = self._process_hint_tokens(hint)
|
||||||
|
if hint_tokens is None:
|
||||||
|
raise RuntimeError("Qwen Fun ControlNet requires a control hint image.")
|
||||||
|
|
||||||
|
if hint_tokens.shape[1] != hidden_states.shape[1]:
|
||||||
|
max_tokens = min(hint_tokens.shape[1], hidden_states.shape[1])
|
||||||
|
hint_tokens = hint_tokens[:, :max_tokens]
|
||||||
|
hidden_states = hidden_states[:, :max_tokens]
|
||||||
|
img_ids = img_ids[:, :max_tokens]
|
||||||
|
|
||||||
|
txt_start = round(
|
||||||
|
max(
|
||||||
|
((x.shape[-1] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||||
|
((x.shape[-2] + (base_model.patch_size // 2)) // base_model.patch_size) // 2,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||||
|
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||||
|
image_rotary_emb = base_model.pe_embedder(ids).to(x.dtype).contiguous()
|
||||||
|
|
||||||
|
hidden_states = base_model.img_in(hidden_states)
|
||||||
|
encoder_hidden_states = base_model.txt_norm(context)
|
||||||
|
encoder_hidden_states = base_model.txt_in(encoder_hidden_states)
|
||||||
|
|
||||||
|
if guidance is not None:
|
||||||
|
guidance = guidance * 1000
|
||||||
|
|
||||||
|
temb = (
|
||||||
|
base_model.time_text_embed(timesteps, hidden_states)
|
||||||
|
if guidance is None
|
||||||
|
else base_model.time_text_embed(timesteps, guidance, hidden_states)
|
||||||
|
)
|
||||||
|
|
||||||
|
c = self.control_img_in(hint_tokens)
|
||||||
|
|
||||||
|
for i, block in enumerate(self.control_blocks):
|
||||||
|
if i == 0:
|
||||||
|
c_in = block.before_proj(c) + hidden_states
|
||||||
|
all_c = []
|
||||||
|
else:
|
||||||
|
all_c = list(torch.unbind(c, dim=0))
|
||||||
|
c_in = all_c.pop(-1)
|
||||||
|
|
||||||
|
encoder_hidden_states, c_out = block(
|
||||||
|
hidden_states=c_in,
|
||||||
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
|
encoder_hidden_states_mask=encoder_hidden_states_mask,
|
||||||
|
temb=temb,
|
||||||
|
image_rotary_emb=image_rotary_emb,
|
||||||
|
transformer_options=transformer_options,
|
||||||
|
)
|
||||||
|
|
||||||
|
c_skip = block.after_proj(c_out) * self.hint_scale
|
||||||
|
all_c += [c_skip, c_out]
|
||||||
|
c = torch.stack(all_c, dim=0)
|
||||||
|
|
||||||
|
hints = torch.unbind(c, dim=0)[:-1]
|
||||||
|
|
||||||
|
controlnet_block_samples = [None] * self.main_model_double
|
||||||
|
for local_idx, base_idx in enumerate(self.injection_layers):
|
||||||
|
if local_idx < len(hints) and base_idx < len(controlnet_block_samples):
|
||||||
|
controlnet_block_samples[base_idx] = hints[local_idx]
|
||||||
|
|
||||||
|
return {"input": controlnet_block_samples}
|
||||||
|
|
||||||
|
|
||||||
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
class QwenImageControlNetModel(QwenImageTransformer2DModel):
|
||||||
|
|||||||
@ -171,8 +171,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
|
|
||||||
def process_tokens(self, tokens, device):
|
def process_tokens(self, tokens, device):
|
||||||
end_token = self.special_tokens.get("end", None)
|
end_token = self.special_tokens.get("end", None)
|
||||||
|
pad_token = self.special_tokens.get("pad", -1)
|
||||||
if end_token is None:
|
if end_token is None:
|
||||||
cmp_token = self.special_tokens.get("pad", -1)
|
cmp_token = pad_token
|
||||||
else:
|
else:
|
||||||
cmp_token = end_token
|
cmp_token = end_token
|
||||||
|
|
||||||
@ -186,15 +187,21 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
other_embeds = []
|
other_embeds = []
|
||||||
eos = False
|
eos = False
|
||||||
index = 0
|
index = 0
|
||||||
|
left_pad = False
|
||||||
for y in x:
|
for y in x:
|
||||||
if isinstance(y, numbers.Integral):
|
if isinstance(y, numbers.Integral):
|
||||||
if eos:
|
token = int(y)
|
||||||
|
if index == 0 and token == pad_token:
|
||||||
|
left_pad = True
|
||||||
|
|
||||||
|
if eos or (left_pad and token == pad_token):
|
||||||
attention_mask.append(0)
|
attention_mask.append(0)
|
||||||
else:
|
else:
|
||||||
attention_mask.append(1)
|
attention_mask.append(1)
|
||||||
token = int(y)
|
left_pad = False
|
||||||
|
|
||||||
tokens_temp += [token]
|
tokens_temp += [token]
|
||||||
if not eos and token == cmp_token:
|
if not eos and token == cmp_token and not left_pad:
|
||||||
if end_token is None:
|
if end_token is None:
|
||||||
attention_mask[-1] = 0
|
attention_mask[-1] = 0
|
||||||
eos = True
|
eos = True
|
||||||
|
|||||||
@ -10,7 +10,6 @@ import comfy.utils
|
|||||||
def sample_manual_loop_no_classes(
|
def sample_manual_loop_no_classes(
|
||||||
model,
|
model,
|
||||||
ids=None,
|
ids=None,
|
||||||
paddings=[],
|
|
||||||
execution_dtype=None,
|
execution_dtype=None,
|
||||||
cfg_scale: float = 2.0,
|
cfg_scale: float = 2.0,
|
||||||
temperature: float = 0.85,
|
temperature: float = 0.85,
|
||||||
@ -36,9 +35,6 @@ def sample_manual_loop_no_classes(
|
|||||||
|
|
||||||
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
|
embeds, attention_mask, num_tokens, embeds_info = model.process_tokens(ids, device)
|
||||||
embeds_batch = embeds.shape[0]
|
embeds_batch = embeds.shape[0]
|
||||||
for i, t in enumerate(paddings):
|
|
||||||
attention_mask[i, :t] = 0
|
|
||||||
attention_mask[i, t:] = 1
|
|
||||||
|
|
||||||
output_audio_codes = []
|
output_audio_codes = []
|
||||||
past_key_values = []
|
past_key_values = []
|
||||||
@ -135,13 +131,11 @@ def generate_audio_codes(model, positive, negative, min_tokens=1, max_tokens=102
|
|||||||
pos_pad = (len(negative) - len(positive))
|
pos_pad = (len(negative) - len(positive))
|
||||||
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
positive = [model.special_tokens["pad"]] * pos_pad + positive
|
||||||
|
|
||||||
paddings = [pos_pad, neg_pad]
|
|
||||||
ids = [positive, negative]
|
ids = [positive, negative]
|
||||||
else:
|
else:
|
||||||
paddings = []
|
|
||||||
ids = [positive]
|
ids = [positive]
|
||||||
|
|
||||||
return sample_manual_loop_no_classes(model, ids, paddings, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
return sample_manual_loop_no_classes(model, ids, cfg_scale=cfg_scale, temperature=temperature, top_p=top_p, top_k=top_k, min_p=min_p, seed=seed, min_tokens=min_tokens, max_new_tokens=max_tokens)
|
||||||
|
|
||||||
|
|
||||||
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
class ACE15Tokenizer(sd1_clip.SD1Tokenizer):
|
||||||
|
|||||||
@ -25,7 +25,7 @@ def ltxv_te(*args, **kwargs):
|
|||||||
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
class Gemma3_12BTokenizer(sd1_clip.SDTokenizer):
|
||||||
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
def __init__(self, embedding_directory=None, tokenizer_data={}):
|
||||||
tokenizer = tokenizer_data.get("spiece_model", None)
|
tokenizer = tokenizer_data.get("spiece_model", None)
|
||||||
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=1, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
super().__init__(tokenizer, pad_with_end=False, embedding_size=3840, embedding_key='gemma3_12b', tokenizer_class=SPieceTokenizer, has_end_token=False, pad_to_max_length=False, max_length=99999999, min_length=512, pad_left=True, disable_weights=True, tokenizer_args={"add_bos": True, "add_eos": False}, tokenizer_data=tokenizer_data)
|
||||||
|
|
||||||
def state_dict(self):
|
def state_dict(self):
|
||||||
return {"spiece_model": self.tokenizer.serialize_model()}
|
return {"spiece_model": self.tokenizer.serialize_model()}
|
||||||
@ -97,6 +97,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
token_weight_pairs = token_weight_pairs["gemma3_12b"]
|
||||||
|
|
||||||
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
out, pooled, extra = self.gemma3_12b.encode_token_weights(token_weight_pairs)
|
||||||
|
out = out[:, :, -torch.sum(extra["attention_mask"]).item():]
|
||||||
out_device = out.device
|
out_device = out.device
|
||||||
if comfy.model_management.should_use_bf16(self.execution_device):
|
if comfy.model_management.should_use_bf16(self.execution_device):
|
||||||
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
|
||||||
@ -138,6 +139,7 @@ class LTXAVTEModel(torch.nn.Module):
|
|||||||
|
|
||||||
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
token_weight_pairs = token_weight_pairs.get("gemma3_12b", [])
|
||||||
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
num_tokens = sum(map(lambda a: len(a), token_weight_pairs))
|
||||||
|
num_tokens = max(num_tokens, 64)
|
||||||
return num_tokens * constant * 1024 * 1024
|
return num_tokens * constant * 1024 * 1024
|
||||||
|
|
||||||
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
|
||||||
|
|||||||
@ -1,4 +1,4 @@
|
|||||||
comfyui-frontend-package==1.38.13
|
comfyui-frontend-package==1.38.14
|
||||||
comfyui-workflow-templates==0.8.38
|
comfyui-workflow-templates==0.8.38
|
||||||
comfyui-embedded-docs==0.4.1
|
comfyui-embedded-docs==0.4.1
|
||||||
torch
|
torch
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user