mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-03 13:52:31 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
178cc311b6
2
.gitignore
vendored
2
.gitignore
vendored
@ -11,7 +11,7 @@ extra_model_paths.yaml
|
|||||||
/.vs
|
/.vs
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
venv/
|
venv*/
|
||||||
.venv/
|
.venv/
|
||||||
/web/extensions/*
|
/web/extensions/*
|
||||||
!/web/extensions/logging.js.example
|
!/web/extensions/logging.js.example
|
||||||
|
|||||||
@ -179,8 +179,8 @@ class LLMAdapter(nn.Module):
|
|||||||
if source_attention_mask.ndim == 2:
|
if source_attention_mask.ndim == 2:
|
||||||
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
source_attention_mask = source_attention_mask.unsqueeze(1).unsqueeze(1)
|
||||||
|
|
||||||
x = self.in_proj(self.embed(target_input_ids))
|
|
||||||
context = source_hidden_states
|
context = source_hidden_states
|
||||||
|
x = self.in_proj(self.embed(target_input_ids, out_dtype=context.dtype))
|
||||||
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
position_ids = torch.arange(x.shape[1], device=x.device).unsqueeze(0)
|
||||||
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
position_ids_context = torch.arange(context.shape[1], device=x.device).unsqueeze(0)
|
||||||
position_embeddings = self.rotary_emb(x, position_ids)
|
position_embeddings = self.rotary_emb(x, position_ids)
|
||||||
|
|||||||
@ -152,6 +152,7 @@ class Chroma(nn.Module):
|
|||||||
transformer_options={},
|
transformer_options={},
|
||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
# running on sequences img
|
# running on sequences img
|
||||||
@ -228,6 +229,7 @@ class Chroma(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if i not in self.skip_dit:
|
if i not in self.skip_dit:
|
||||||
|
|||||||
@ -196,6 +196,9 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
(img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
# prepare image for attention
|
# prepare image for attention
|
||||||
img_modulated = self.img_norm1(img)
|
img_modulated = self.img_norm1(img)
|
||||||
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
|
||||||
@ -224,6 +227,12 @@ class DoubleStreamBlock(nn.Module):
|
|||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
|
if "attn1_output_patch" in transformer_patches:
|
||||||
|
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||||
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
attn = p(attn, extra_options)
|
||||||
|
|
||||||
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]
|
||||||
|
|
||||||
# calculate the img bloks
|
# calculate the img bloks
|
||||||
@ -303,6 +312,9 @@ class SingleStreamBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
mod = vec
|
mod = vec
|
||||||
|
|
||||||
|
transformer_patches = transformer_options.get("patches", {})
|
||||||
|
extra_options = transformer_options.copy()
|
||||||
|
|
||||||
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
qkv, mlp = torch.split(self.linear1(apply_mod(self.pre_norm(x), (1 + mod.scale), mod.shift, modulation_dims)), [3 * self.hidden_size, self.mlp_hidden_dim_first], dim=-1)
|
||||||
|
|
||||||
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||||
@ -312,6 +324,12 @@ class SingleStreamBlock(nn.Module):
|
|||||||
# compute attention
|
# compute attention
|
||||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||||
del q, k, v
|
del q, k, v
|
||||||
|
|
||||||
|
if "attn1_output_patch" in transformer_patches:
|
||||||
|
patch = transformer_patches["attn1_output_patch"]
|
||||||
|
for p in patch:
|
||||||
|
attn = p(attn, extra_options)
|
||||||
|
|
||||||
# compute activation in mlp stream, cat again and run second linear layer
|
# compute activation in mlp stream, cat again and run second linear layer
|
||||||
if self.yak_mlp:
|
if self.yak_mlp:
|
||||||
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
mlp = self.mlp_act(mlp[..., self.mlp_hidden_dim_first // 2:]) * mlp[..., :self.mlp_hidden_dim_first // 2]
|
||||||
|
|||||||
@ -142,6 +142,7 @@ class Flux(nn.Module):
|
|||||||
attn_mask: Tensor = None,
|
attn_mask: Tensor = None,
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches = transformer_options.get("patches", {})
|
patches = transformer_options.get("patches", {})
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
if img.ndim != 3 or txt.ndim != 3:
|
if img.ndim != 3 or txt.ndim != 3:
|
||||||
@ -231,6 +232,7 @@ class Flux(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
|
|||||||
@ -304,6 +304,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
control=None,
|
control=None,
|
||||||
transformer_options={},
|
transformer_options={},
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
|
transformer_options = transformer_options.copy()
|
||||||
patches_replace = transformer_options.get("patches_replace", {})
|
patches_replace = transformer_options.get("patches_replace", {})
|
||||||
|
|
||||||
initial_shape = list(img.shape)
|
initial_shape = list(img.shape)
|
||||||
@ -416,6 +417,7 @@ class HunyuanVideo(nn.Module):
|
|||||||
|
|
||||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||||
transformer_options["block_type"] = "single"
|
transformer_options["block_type"] = "single"
|
||||||
|
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||||
for i, block in enumerate(self.single_blocks):
|
for i, block in enumerate(self.single_blocks):
|
||||||
transformer_options["block_index"] = i
|
transformer_options["block_index"] = i
|
||||||
if ("single_block", i) in blocks_replace:
|
if ("single_block", i) in blocks_replace:
|
||||||
|
|||||||
@ -178,10 +178,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
xc = torch.cat([xc] + [comfy.model_management.cast_to_device(c_concat, xc.device, xc.dtype)], dim=1)
|
||||||
|
|
||||||
context = c_crossattn
|
context = c_crossattn
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype_inference()
|
||||||
|
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
|
|
||||||
xc = xc.to(dtype)
|
xc = xc.to(dtype)
|
||||||
device = xc.device
|
device = xc.device
|
||||||
@ -218,6 +215,13 @@ class BaseModel(torch.nn.Module):
|
|||||||
def get_dtype(self):
|
def get_dtype(self):
|
||||||
return self.diffusion_model.dtype
|
return self.diffusion_model.dtype
|
||||||
|
|
||||||
|
def get_dtype_inference(self):
|
||||||
|
dtype = self.get_dtype()
|
||||||
|
|
||||||
|
if self.manual_cast_dtype is not None:
|
||||||
|
dtype = self.manual_cast_dtype
|
||||||
|
return dtype
|
||||||
|
|
||||||
def encode_adm(self, **kwargs):
|
def encode_adm(self, **kwargs):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -372,9 +376,7 @@ class BaseModel(torch.nn.Module):
|
|||||||
input_shapes += shape
|
input_shapes += shape
|
||||||
|
|
||||||
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
if comfy.model_management.xformers_enabled() or comfy.model_management.pytorch_attention_flash_attention():
|
||||||
dtype = self.get_dtype()
|
dtype = self.get_dtype_inference()
|
||||||
if self.manual_cast_dtype is not None:
|
|
||||||
dtype = self.manual_cast_dtype
|
|
||||||
#TODO: this needs to be tweaked
|
#TODO: this needs to be tweaked
|
||||||
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
area = sum(map(lambda input_shape: input_shape[0] * math.prod(input_shape[2:]), input_shapes))
|
||||||
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
return (area * comfy.model_management.dtype_size(dtype) * 0.01 * self.memory_usage_factor) * (1024 * 1024)
|
||||||
@ -1165,7 +1167,7 @@ class Anima(BaseModel):
|
|||||||
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
t5xxl_ids = t5xxl_ids.unsqueeze(0)
|
||||||
|
|
||||||
if torch.is_inference_mode_enabled(): # if not we are training
|
if torch.is_inference_mode_enabled(): # if not we are training
|
||||||
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype()))
|
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), t5xxl_ids.to(device=device), t5xxl_weights=t5xxl_weights.to(device=device, dtype=self.get_dtype_inference()))
|
||||||
else:
|
else:
|
||||||
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
out['t5xxl_ids'] = comfy.conds.CONDRegular(t5xxl_ids)
|
||||||
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
out['t5xxl_weights'] = comfy.conds.CONDRegular(t5xxl_weights)
|
||||||
|
|||||||
@ -406,13 +406,16 @@ class ModelPatcher:
|
|||||||
def memory_required(self, input_shape):
|
def memory_required(self, input_shape):
|
||||||
return self.model.memory_required(input_shape=input_shape)
|
return self.model.memory_required(input_shape=input_shape)
|
||||||
|
|
||||||
|
def disable_model_cfg1_optimization(self):
|
||||||
|
self.model_options["disable_cfg1_optimization"] = True
|
||||||
|
|
||||||
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
def set_model_sampler_cfg_function(self, sampler_cfg_function, disable_cfg1_optimization=False):
|
||||||
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
if len(inspect.signature(sampler_cfg_function).parameters) == 3:
|
||||||
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
self.model_options["sampler_cfg_function"] = lambda args: sampler_cfg_function(args["cond"], args["uncond"], args["cond_scale"]) #Old way
|
||||||
else:
|
else:
|
||||||
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
self.model_options["sampler_cfg_function"] = sampler_cfg_function
|
||||||
if disable_cfg1_optimization:
|
if disable_cfg1_optimization:
|
||||||
self.model_options["disable_cfg1_optimization"] = True
|
self.disable_model_cfg1_optimization()
|
||||||
|
|
||||||
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
def set_model_sampler_post_cfg_function(self, post_cfg_function, disable_cfg1_optimization=False):
|
||||||
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
self.model_options = set_model_options_post_cfg_function(self.model_options, post_cfg_function, disable_cfg1_optimization)
|
||||||
|
|||||||
23
comfy/ops.py
23
comfy/ops.py
@ -21,7 +21,6 @@ import logging
|
|||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
from comfy.cli_args import args, PerformanceFeature, enables_dynamic_vram
|
||||||
import comfy.float
|
import comfy.float
|
||||||
import comfy.rmsnorm
|
|
||||||
import json
|
import json
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.pinned_memory
|
import comfy.pinned_memory
|
||||||
@ -80,7 +79,7 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
|
|||||||
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy)
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype):
|
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
xfer_dest = None
|
xfer_dest = None
|
||||||
|
|
||||||
@ -171,10 +170,10 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
#FIXME: this is not accurate, we need to be sensitive to the compute dtype
|
||||||
x = lowvram_fn(x)
|
x = lowvram_fn(x)
|
||||||
if (isinstance(orig, QuantizedTensor) and
|
if (isinstance(orig, QuantizedTensor) and
|
||||||
(orig.dtype == dtype and len(fns) == 0 or update_weight)):
|
(want_requant and len(fns) == 0 or update_weight)):
|
||||||
seed = comfy.utils.string_to_seed(s.seed_key)
|
seed = comfy.utils.string_to_seed(s.seed_key)
|
||||||
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
y = QuantizedTensor.from_float(x, s.layout_type, scale="recalculate", stochastic_rounding=seed)
|
||||||
if orig.dtype == dtype and len(fns) == 0:
|
if want_requant and len(fns) == 0:
|
||||||
#The layer actually wants our freshly saved QT
|
#The layer actually wants our freshly saved QT
|
||||||
x = y
|
x = y
|
||||||
elif update_weight:
|
elif update_weight:
|
||||||
@ -195,7 +194,7 @@ def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compu
|
|||||||
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
return weight, bias, (offload_stream, device if signature is not None else None, None)
|
||||||
|
|
||||||
|
|
||||||
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None):
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, offloadable=False, compute_dtype=None, want_requant=False):
|
||||||
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
# NOTE: offloadable=False is a a legacy and if you are a custom node author reading this please pass
|
||||||
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
# offloadable=True and call uncast_bias_weight() after your last usage of the weight/bias. This
|
||||||
# will add async-offload support to your cast and improve performance.
|
# will add async-offload support to your cast and improve performance.
|
||||||
@ -213,7 +212,7 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
|||||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||||
|
|
||||||
if hasattr(s, "_v"):
|
if hasattr(s, "_v"):
|
||||||
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype)
|
return cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant)
|
||||||
|
|
||||||
if offloadable and (device != s.weight.device or
|
if offloadable and (device != s.weight.device or
|
||||||
(s.bias is not None and device != s.bias.device)):
|
(s.bias is not None and device != s.bias.device)):
|
||||||
@ -463,7 +462,7 @@ class disable_weight_init:
|
|||||||
else:
|
else:
|
||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
|
class RMSNorm(torch.nn.RMSNorm, CastWeightBiasOp):
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
self.bias = None
|
self.bias = None
|
||||||
return None
|
return None
|
||||||
@ -475,8 +474,7 @@ class disable_weight_init:
|
|||||||
weight = None
|
weight = None
|
||||||
bias = None
|
bias = None
|
||||||
offload_stream = None
|
offload_stream = None
|
||||||
x = comfy.rmsnorm.rms_norm(input, weight, self.eps) # TODO: switch to commented out line when old torch is deprecated
|
x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
||||||
# x = torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)
|
|
||||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -852,8 +850,8 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def _forward(self, input, weight, bias):
|
def _forward(self, input, weight, bias):
|
||||||
return torch.nn.functional.linear(input, weight, bias)
|
return torch.nn.functional.linear(input, weight, bias)
|
||||||
|
|
||||||
def forward_comfy_cast_weights(self, input, compute_dtype=None):
|
def forward_comfy_cast_weights(self, input, compute_dtype=None, want_requant=False):
|
||||||
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype)
|
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True, compute_dtype=compute_dtype, want_requant=want_requant)
|
||||||
x = self._forward(input, weight, bias)
|
x = self._forward(input, weight, bias)
|
||||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
return x
|
return x
|
||||||
@ -883,8 +881,7 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||||
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)
|
||||||
|
|
||||||
|
output = self.forward_comfy_cast_weights(input, compute_dtype, want_requant=isinstance(input, QuantizedTensor))
|
||||||
output = self.forward_comfy_cast_weights(input, compute_dtype)
|
|
||||||
|
|
||||||
# Reshape output back to 3D if input was 3D
|
# Reshape output back to 3D if input was 3D
|
||||||
if reshaped_3d:
|
if reshaped_3d:
|
||||||
|
|||||||
@ -1,57 +1,10 @@
|
|||||||
import torch
|
import torch
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import numbers
|
|
||||||
import logging
|
|
||||||
|
|
||||||
RMSNorm = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
rms_norm_torch = torch.nn.functional.rms_norm
|
|
||||||
RMSNorm = torch.nn.RMSNorm
|
|
||||||
except:
|
|
||||||
rms_norm_torch = None
|
|
||||||
logging.warning("Please update pytorch to use native RMSNorm")
|
|
||||||
|
|
||||||
|
RMSNorm = torch.nn.RMSNorm
|
||||||
|
|
||||||
def rms_norm(x, weight=None, eps=1e-6):
|
def rms_norm(x, weight=None, eps=1e-6):
|
||||||
if rms_norm_torch is not None and not (torch.jit.is_tracing() or torch.jit.is_scripting()):
|
if weight is None:
|
||||||
if weight is None:
|
return torch.nn.functional.rms_norm(x, (x.shape[-1],), eps=eps)
|
||||||
return rms_norm_torch(x, (x.shape[-1],), eps=eps)
|
|
||||||
else:
|
|
||||||
return rms_norm_torch(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
|
||||||
else:
|
else:
|
||||||
r = x * torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
|
return torch.nn.functional.rms_norm(x, weight.shape, weight=comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device), eps=eps)
|
||||||
if weight is None:
|
|
||||||
return r
|
|
||||||
else:
|
|
||||||
return r * comfy.model_management.cast_to(weight, dtype=x.dtype, device=x.device)
|
|
||||||
|
|
||||||
|
|
||||||
if RMSNorm is None:
|
|
||||||
class RMSNorm(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
normalized_shape,
|
|
||||||
eps=1e-6,
|
|
||||||
elementwise_affine=True,
|
|
||||||
device=None,
|
|
||||||
dtype=None,
|
|
||||||
):
|
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
|
||||||
super().__init__()
|
|
||||||
if isinstance(normalized_shape, numbers.Integral):
|
|
||||||
# mypy error: incompatible types in assignment
|
|
||||||
normalized_shape = (normalized_shape,) # type: ignore[assignment]
|
|
||||||
self.normalized_shape = tuple(normalized_shape) # type: ignore[arg-type]
|
|
||||||
self.eps = eps
|
|
||||||
self.elementwise_affine = elementwise_affine
|
|
||||||
if self.elementwise_affine:
|
|
||||||
self.weight = torch.nn.Parameter(
|
|
||||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
self.register_parameter("weight", None)
|
|
||||||
self.bias = None
|
|
||||||
|
|
||||||
def forward(self, x):
|
|
||||||
return rms_norm(x, self.weight, self.eps)
|
|
||||||
|
|||||||
@ -75,6 +75,12 @@ class NumberDisplay(str, Enum):
|
|||||||
slider = "slider"
|
slider = "slider"
|
||||||
|
|
||||||
|
|
||||||
|
class ControlAfterGenerate(str, Enum):
|
||||||
|
fixed = "fixed"
|
||||||
|
increment = "increment"
|
||||||
|
decrement = "decrement"
|
||||||
|
randomize = "randomize"
|
||||||
|
|
||||||
class _ComfyType(ABC):
|
class _ComfyType(ABC):
|
||||||
Type = Any
|
Type = Any
|
||||||
io_type: str = None
|
io_type: str = None
|
||||||
@ -263,7 +269,7 @@ class Int(ComfyTypeIO):
|
|||||||
class Input(WidgetInput):
|
class Input(WidgetInput):
|
||||||
'''Integer input.'''
|
'''Integer input.'''
|
||||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool=None,
|
default: int=None, min: int=None, max: int=None, step: int=None, control_after_generate: bool | ControlAfterGenerate=None,
|
||||||
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
display_mode: NumberDisplay=None, socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||||
self.min = min
|
self.min = min
|
||||||
@ -345,7 +351,7 @@ class Combo(ComfyTypeIO):
|
|||||||
tooltip: str=None,
|
tooltip: str=None,
|
||||||
lazy: bool=None,
|
lazy: bool=None,
|
||||||
default: str | int | Enum = None,
|
default: str | int | Enum = None,
|
||||||
control_after_generate: bool=None,
|
control_after_generate: bool | ControlAfterGenerate=None,
|
||||||
upload: UploadType=None,
|
upload: UploadType=None,
|
||||||
image_folder: FolderType=None,
|
image_folder: FolderType=None,
|
||||||
remote: RemoteOptions=None,
|
remote: RemoteOptions=None,
|
||||||
@ -389,7 +395,7 @@ class MultiCombo(ComfyTypeI):
|
|||||||
Type = list[str]
|
Type = list[str]
|
||||||
class Input(Combo.Input):
|
class Input(Combo.Input):
|
||||||
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
def __init__(self, id: str, options: list[str], display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||||
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool=None,
|
default: list[str]=None, placeholder: str=None, chip: bool=None, control_after_generate: bool | ControlAfterGenerate=None,
|
||||||
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
socketless: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||||
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
|
super().__init__(id, options, display_name, optional, tooltip, lazy, default, control_after_generate, socketless=socketless, extra_dict=extra_dict, raw_link=raw_link, advanced=advanced)
|
||||||
self.multiselect = True
|
self.multiselect = True
|
||||||
@ -2097,6 +2103,7 @@ __all__ = [
|
|||||||
"UploadType",
|
"UploadType",
|
||||||
"RemoteOptions",
|
"RemoteOptions",
|
||||||
"NumberDisplay",
|
"NumberDisplay",
|
||||||
|
"ControlAfterGenerate",
|
||||||
|
|
||||||
"comfytype",
|
"comfytype",
|
||||||
"Custom",
|
"Custom",
|
||||||
|
|||||||
@ -198,11 +198,6 @@ dict_recraft_substyles_v3 = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class RecraftModel(str, Enum):
|
|
||||||
recraftv3 = 'recraftv3'
|
|
||||||
recraftv2 = 'recraftv2'
|
|
||||||
|
|
||||||
|
|
||||||
class RecraftImageSize(str, Enum):
|
class RecraftImageSize(str, Enum):
|
||||||
res_1024x1024 = '1024x1024'
|
res_1024x1024 = '1024x1024'
|
||||||
res_1365x1024 = '1365x1024'
|
res_1365x1024 = '1365x1024'
|
||||||
@ -221,6 +216,41 @@ class RecraftImageSize(str, Enum):
|
|||||||
res_1707x1024 = '1707x1024'
|
res_1707x1024 = '1707x1024'
|
||||||
|
|
||||||
|
|
||||||
|
RECRAFT_V4_SIZES = [
|
||||||
|
"1024x1024",
|
||||||
|
"1536x768",
|
||||||
|
"768x1536",
|
||||||
|
"1280x832",
|
||||||
|
"832x1280",
|
||||||
|
"1216x896",
|
||||||
|
"896x1216",
|
||||||
|
"1152x896",
|
||||||
|
"896x1152",
|
||||||
|
"832x1344",
|
||||||
|
"1280x896",
|
||||||
|
"896x1280",
|
||||||
|
"1344x768",
|
||||||
|
"768x1344",
|
||||||
|
]
|
||||||
|
|
||||||
|
RECRAFT_V4_PRO_SIZES = [
|
||||||
|
"2048x2048",
|
||||||
|
"3072x1536",
|
||||||
|
"1536x3072",
|
||||||
|
"2560x1664",
|
||||||
|
"1664x2560",
|
||||||
|
"2432x1792",
|
||||||
|
"1792x2432",
|
||||||
|
"2304x1792",
|
||||||
|
"1792x2304",
|
||||||
|
"1664x2688",
|
||||||
|
"1434x1024",
|
||||||
|
"1024x1434",
|
||||||
|
"2560x1792",
|
||||||
|
"1792x2560",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class RecraftColorObject(BaseModel):
|
class RecraftColorObject(BaseModel):
|
||||||
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
|
rgb: list[int] = Field(..., description='An array of 3 integer values in range of 0...255 defining RGB Color Model')
|
||||||
|
|
||||||
@ -234,17 +264,16 @@ class RecraftControlsObject(BaseModel):
|
|||||||
|
|
||||||
class RecraftImageGenerationRequest(BaseModel):
|
class RecraftImageGenerationRequest(BaseModel):
|
||||||
prompt: str = Field(..., description='The text prompt describing the image to generate')
|
prompt: str = Field(..., description='The text prompt describing the image to generate')
|
||||||
size: RecraftImageSize | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
size: str | None = Field(None, description='The size of the generated image (e.g., "1024x1024")')
|
||||||
n: int = Field(..., description='The number of images to generate')
|
n: int = Field(..., description='The number of images to generate')
|
||||||
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
|
negative_prompt: str | None = Field(None, description='A text description of undesired elements on an image')
|
||||||
model: RecraftModel | None = Field(RecraftModel.recraftv3, description='The model to use for generation (e.g., "recraftv3")')
|
model: str = Field(...)
|
||||||
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
style: str | None = Field(None, description='The style to apply to the generated image (e.g., "digital_illustration")')
|
||||||
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
substyle: str | None = Field(None, description='The substyle to apply to the generated image, depending on the style input')
|
||||||
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
|
controls: RecraftControlsObject | None = Field(None, description='A set of custom parameters to tweak generation process')
|
||||||
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
style_id: str | None = Field(None, description='Use a previously uploaded style as a reference; UUID')
|
||||||
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
strength: float | None = Field(None, description='Defines the difference with the original image, should lie in [0, 1], where 0 means almost identical, and 1 means miserable similarity')
|
||||||
random_seed: int | None = Field(None, description="Seed for video generation")
|
random_seed: int | None = Field(None, description="Seed for video generation")
|
||||||
# text_layout
|
|
||||||
|
|
||||||
|
|
||||||
class RecraftReturnedObject(BaseModel):
|
class RecraftReturnedObject(BaseModel):
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Optional, Union
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import torch
|
import torch
|
||||||
@ -9,6 +8,8 @@ from typing_extensions import override
|
|||||||
from comfy.utils import ProgressBar
|
from comfy.utils import ProgressBar
|
||||||
from comfy_api.latest import IO, ComfyExtension
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
from comfy_api_nodes.apis.recraft import (
|
from comfy_api_nodes.apis.recraft import (
|
||||||
|
RECRAFT_V4_PRO_SIZES,
|
||||||
|
RECRAFT_V4_SIZES,
|
||||||
RecraftColor,
|
RecraftColor,
|
||||||
RecraftColorChain,
|
RecraftColorChain,
|
||||||
RecraftControls,
|
RecraftControls,
|
||||||
@ -18,7 +19,6 @@ from comfy_api_nodes.apis.recraft import (
|
|||||||
RecraftImageGenerationResponse,
|
RecraftImageGenerationResponse,
|
||||||
RecraftImageSize,
|
RecraftImageSize,
|
||||||
RecraftIO,
|
RecraftIO,
|
||||||
RecraftModel,
|
|
||||||
RecraftStyle,
|
RecraftStyle,
|
||||||
RecraftStyleV3,
|
RecraftStyleV3,
|
||||||
get_v3_substyles,
|
get_v3_substyles,
|
||||||
@ -39,7 +39,7 @@ async def handle_recraft_file_request(
|
|||||||
cls: type[IO.ComfyNode],
|
cls: type[IO.ComfyNode],
|
||||||
image: torch.Tensor,
|
image: torch.Tensor,
|
||||||
path: str,
|
path: str,
|
||||||
mask: Optional[torch.Tensor] = None,
|
mask: torch.Tensor | None = None,
|
||||||
total_pixels: int = 4096 * 4096,
|
total_pixels: int = 4096 * 4096,
|
||||||
timeout: int = 1024,
|
timeout: int = 1024,
|
||||||
request=None,
|
request=None,
|
||||||
@ -73,11 +73,11 @@ async def handle_recraft_file_request(
|
|||||||
def recraft_multipart_parser(
|
def recraft_multipart_parser(
|
||||||
data,
|
data,
|
||||||
parent_key=None,
|
parent_key=None,
|
||||||
formatter: Optional[type[callable]] = None,
|
formatter: type[callable] | None = None,
|
||||||
converted_to_check: Optional[list[list]] = None,
|
converted_to_check: list[list] | None = None,
|
||||||
is_list: bool = False,
|
is_list: bool = False,
|
||||||
return_mode: str = "formdata", # "dict" | "formdata"
|
return_mode: str = "formdata", # "dict" | "formdata"
|
||||||
) -> Union[dict, aiohttp.FormData]:
|
) -> dict | aiohttp.FormData:
|
||||||
"""
|
"""
|
||||||
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
|
Formats data such that multipart/form-data will work with aiohttp library when both files and data are present.
|
||||||
|
|
||||||
@ -309,7 +309,7 @@ class RecraftStyleInfiniteStyleLibrary(IO.ComfyNode):
|
|||||||
node_id="RecraftStyleV3InfiniteStyleLibrary",
|
node_id="RecraftStyleV3InfiniteStyleLibrary",
|
||||||
display_name="Recraft Style - Infinite Style Library",
|
display_name="Recraft Style - Infinite Style Library",
|
||||||
category="api node/image/Recraft",
|
category="api node/image/Recraft",
|
||||||
description="Select style based on preexisting UUID from Recraft's Infinite Style Library.",
|
description="Choose style based on preexisting UUID from Recraft's Infinite Style Library.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
|
IO.String.Input("style_id", default="", tooltip="UUID of style from Infinite Style Library."),
|
||||||
],
|
],
|
||||||
@ -485,7 +485,7 @@ class RecraftTextToImageNode(IO.ComfyNode):
|
|||||||
data=RecraftImageGenerationRequest(
|
data=RecraftImageGenerationRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
model=RecraftModel.recraftv3,
|
model="recraftv3",
|
||||||
size=size,
|
size=size,
|
||||||
n=n,
|
n=n,
|
||||||
style=recraft_style.style,
|
style=recraft_style.style,
|
||||||
@ -598,7 +598,7 @@ class RecraftImageToImageNode(IO.ComfyNode):
|
|||||||
request = RecraftImageGenerationRequest(
|
request = RecraftImageGenerationRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
model=RecraftModel.recraftv3,
|
model="recraftv3",
|
||||||
n=n,
|
n=n,
|
||||||
strength=round(strength, 2),
|
strength=round(strength, 2),
|
||||||
style=recraft_style.style,
|
style=recraft_style.style,
|
||||||
@ -698,7 +698,7 @@ class RecraftImageInpaintingNode(IO.ComfyNode):
|
|||||||
request = RecraftImageGenerationRequest(
|
request = RecraftImageGenerationRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
model=RecraftModel.recraftv3,
|
model="recraftv3",
|
||||||
n=n,
|
n=n,
|
||||||
style=recraft_style.style,
|
style=recraft_style.style,
|
||||||
substyle=recraft_style.substyle,
|
substyle=recraft_style.substyle,
|
||||||
@ -810,7 +810,7 @@ class RecraftTextToVectorNode(IO.ComfyNode):
|
|||||||
data=RecraftImageGenerationRequest(
|
data=RecraftImageGenerationRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
model=RecraftModel.recraftv3,
|
model="recraftv3",
|
||||||
size=size,
|
size=size,
|
||||||
n=n,
|
n=n,
|
||||||
style=recraft_style.style,
|
style=recraft_style.style,
|
||||||
@ -933,7 +933,7 @@ class RecraftReplaceBackgroundNode(IO.ComfyNode):
|
|||||||
request = RecraftImageGenerationRequest(
|
request = RecraftImageGenerationRequest(
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
negative_prompt=negative_prompt,
|
negative_prompt=negative_prompt,
|
||||||
model=RecraftModel.recraftv3,
|
model="recraftv3",
|
||||||
n=n,
|
n=n,
|
||||||
style=recraft_style.style,
|
style=recraft_style.style,
|
||||||
substyle=recraft_style.substyle,
|
substyle=recraft_style.substyle,
|
||||||
@ -1078,6 +1078,252 @@ class RecraftCreativeUpscaleNode(RecraftCrispUpscaleNode):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftV4TextToImageNode(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="RecraftV4TextToImageNode",
|
||||||
|
display_name="Recraft V4 Text to Image",
|
||||||
|
category="api node/image/Recraft",
|
||||||
|
description="Generates images using Recraft V4 or V4 Pro models.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Prompt for the image generation. Maximum 10,000 characters.",
|
||||||
|
),
|
||||||
|
IO.String.Input(
|
||||||
|
"negative_prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="An optional text description of undesired elements on an image.",
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"recraftv4",
|
||||||
|
[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"size",
|
||||||
|
options=RECRAFT_V4_SIZES,
|
||||||
|
default="1024x1024",
|
||||||
|
tooltip="The size of the generated image.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"recraftv4_pro",
|
||||||
|
[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"size",
|
||||||
|
options=RECRAFT_V4_PRO_SIZES,
|
||||||
|
default="2048x2048",
|
||||||
|
tooltip="The size of the generated image.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for generation.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"n",
|
||||||
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=6,
|
||||||
|
tooltip="The number of images to generate.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
IO.Custom(RecraftIO.CONTROLS).Input(
|
||||||
|
"recraft_controls",
|
||||||
|
tooltip="Optional additional controls over the generation via the Recraft Controls node.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Image.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$prices := {"recraftv4": 0.04, "recraftv4_pro": 0.25};
|
||||||
|
{"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
model: dict,
|
||||||
|
n: int,
|
||||||
|
seed: int,
|
||||||
|
recraft_controls: RecraftControls | None = None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
|
||||||
|
response_model=RecraftImageGenerationResponse,
|
||||||
|
data=RecraftImageGenerationRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
|
model=model["model"],
|
||||||
|
size=model["size"],
|
||||||
|
n=n,
|
||||||
|
controls=recraft_controls.create_api_model() if recraft_controls else None,
|
||||||
|
),
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
images = []
|
||||||
|
for data in response.data:
|
||||||
|
with handle_recraft_image_output():
|
||||||
|
image = bytesio_to_image_tensor(await download_url_as_bytesio(data.url, timeout=1024))
|
||||||
|
if len(image.shape) < 4:
|
||||||
|
image = image.unsqueeze(0)
|
||||||
|
images.append(image)
|
||||||
|
return IO.NodeOutput(torch.cat(images, dim=0))
|
||||||
|
|
||||||
|
|
||||||
|
class RecraftV4TextToVectorNode(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="RecraftV4TextToVectorNode",
|
||||||
|
display_name="Recraft V4 Text to Vector",
|
||||||
|
category="api node/image/Recraft",
|
||||||
|
description="Generates SVG using Recraft V4 or V4 Pro models.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Prompt for the image generation. Maximum 10,000 characters.",
|
||||||
|
),
|
||||||
|
IO.String.Input(
|
||||||
|
"negative_prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="An optional text description of undesired elements on an image.",
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"recraftv4",
|
||||||
|
[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"size",
|
||||||
|
options=RECRAFT_V4_SIZES,
|
||||||
|
default="1024x1024",
|
||||||
|
tooltip="The size of the generated image.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"recraftv4_pro",
|
||||||
|
[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"size",
|
||||||
|
options=RECRAFT_V4_PRO_SIZES,
|
||||||
|
default="2048x2048",
|
||||||
|
tooltip="The size of the generated image.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="The model to use for generation.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"n",
|
||||||
|
default=1,
|
||||||
|
min=1,
|
||||||
|
max=6,
|
||||||
|
tooltip="The number of images to generate.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=0xFFFFFFFFFFFFFFFF,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
IO.Custom(RecraftIO.CONTROLS).Input(
|
||||||
|
"recraft_controls",
|
||||||
|
tooltip="Optional additional controls over the generation via the Recraft Controls node.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.SVG.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "n"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$prices := {"recraftv4": 0.08, "recraftv4_pro": 0.30};
|
||||||
|
{"type":"usd","usd": $lookup($prices, widgets.model) * widgets.n}
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
negative_prompt: str,
|
||||||
|
model: dict,
|
||||||
|
n: int,
|
||||||
|
seed: int,
|
||||||
|
recraft_controls: RecraftControls | None = None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=False, min_length=1, max_length=10000)
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/recraft/image_generation", method="POST"),
|
||||||
|
response_model=RecraftImageGenerationResponse,
|
||||||
|
data=RecraftImageGenerationRequest(
|
||||||
|
prompt=prompt,
|
||||||
|
negative_prompt=negative_prompt if negative_prompt else None,
|
||||||
|
model=model["model"],
|
||||||
|
size=model["size"],
|
||||||
|
n=n,
|
||||||
|
style="vector_illustration",
|
||||||
|
substyle=None,
|
||||||
|
controls=recraft_controls.create_api_model() if recraft_controls else None,
|
||||||
|
),
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
svg_data = []
|
||||||
|
for data in response.data:
|
||||||
|
svg_data.append(await download_url_as_bytesio(data.url, timeout=1024))
|
||||||
|
return IO.NodeOutput(SVG(svg_data))
|
||||||
|
|
||||||
|
|
||||||
class RecraftExtension(ComfyExtension):
|
class RecraftExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -1098,6 +1344,8 @@ class RecraftExtension(ComfyExtension):
|
|||||||
RecraftCreateStyleNode,
|
RecraftCreateStyleNode,
|
||||||
RecraftColorRGBNode,
|
RecraftColorRGBNode,
|
||||||
RecraftControlsNode,
|
RecraftControlsNode,
|
||||||
|
RecraftV4TextToImageNode,
|
||||||
|
RecraftV4TextToVectorNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -54,6 +54,7 @@ async def execute_task(
|
|||||||
response_model=TaskStatusResponse,
|
response_model=TaskStatusResponse,
|
||||||
status_extractor=lambda r: r.state,
|
status_extractor=lambda r: r.state,
|
||||||
progress_extractor=lambda r: r.progress,
|
progress_extractor=lambda r: r.progress,
|
||||||
|
price_extractor=lambda r: r.credits * 0.005 if r.credits is not None else None,
|
||||||
max_poll_attempts=max_poll_attempts,
|
max_poll_attempts=max_poll_attempts,
|
||||||
)
|
)
|
||||||
if not response.creations:
|
if not response.creations:
|
||||||
@ -1306,6 +1307,36 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"viduq3-turbo",
|
||||||
|
[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"aspect_ratio",
|
||||||
|
options=["16:9", "9:16", "3:4", "4:3", "1:1"],
|
||||||
|
tooltip="The aspect ratio of the output video.",
|
||||||
|
),
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=["720p", "1080p"],
|
||||||
|
tooltip="Resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=5,
|
||||||
|
min=1,
|
||||||
|
max=16,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Duration of the output video in seconds.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"audio",
|
||||||
|
default=False,
|
||||||
|
tooltip="When enabled, outputs video with sound "
|
||||||
|
"(including dialogue and sound effects).",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
],
|
],
|
||||||
tooltip="Model to use for video generation.",
|
tooltip="Model to use for video generation.",
|
||||||
),
|
),
|
||||||
@ -1334,13 +1365,20 @@ class Vidu3TextToVideoNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
$res := $lookup(widgets, "model.resolution");
|
$res := $lookup(widgets, "model.resolution");
|
||||||
$base := $lookup({"720p": 0.075, "1080p": 0.1}, $res);
|
$d := $lookup(widgets, "model.duration");
|
||||||
$perSec := $lookup({"720p": 0.025, "1080p": 0.05}, $res);
|
$contains(widgets.model, "turbo")
|
||||||
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
|
? (
|
||||||
|
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
|
||||||
|
{"type":"usd","usd": $rate * $d}
|
||||||
|
)
|
||||||
|
: (
|
||||||
|
$rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
|
||||||
|
{"type":"usd","usd": $rate * $d}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
@ -1409,6 +1447,31 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"viduq3-turbo",
|
||||||
|
[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=["720p", "1080p"],
|
||||||
|
tooltip="Resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=5,
|
||||||
|
min=1,
|
||||||
|
max=16,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Duration of the output video in seconds.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"audio",
|
||||||
|
default=False,
|
||||||
|
tooltip="When enabled, outputs video with sound "
|
||||||
|
"(including dialogue and sound effects).",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
],
|
],
|
||||||
tooltip="Model to use for video generation.",
|
tooltip="Model to use for video generation.",
|
||||||
),
|
),
|
||||||
@ -1442,13 +1505,20 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
is_api_node=True,
|
is_api_node=True,
|
||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
depends_on=IO.PriceBadgeDepends(widgets=["model.duration", "model.resolution"]),
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
|
||||||
expr="""
|
expr="""
|
||||||
(
|
(
|
||||||
$res := $lookup(widgets, "model.resolution");
|
$res := $lookup(widgets, "model.resolution");
|
||||||
$base := $lookup({"720p": 0.075, "1080p": 0.275, "2k": 0.35}, $res);
|
$d := $lookup(widgets, "model.duration");
|
||||||
$perSec := $lookup({"720p": 0.05, "1080p": 0.075, "2k": 0.075}, $res);
|
$contains(widgets.model, "turbo")
|
||||||
{"type":"usd","usd": $base + $perSec * ($lookup(widgets, "model.duration") - 1)}
|
? (
|
||||||
|
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
|
||||||
|
{"type":"usd","usd": $rate * $d}
|
||||||
|
)
|
||||||
|
: (
|
||||||
|
$rate := $lookup({"720p": 0.15, "1080p": 0.16, "2k": 0.2}, $res);
|
||||||
|
{"type":"usd","usd": $rate * $d}
|
||||||
|
)
|
||||||
)
|
)
|
||||||
""",
|
""",
|
||||||
),
|
),
|
||||||
@ -1481,6 +1551,145 @@ class Vidu3ImageToVideoNode(IO.ComfyNode):
|
|||||||
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||||
|
|
||||||
|
|
||||||
|
class Vidu3StartEndToVideoNode(IO.ComfyNode):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="Vidu3StartEndToVideoNode",
|
||||||
|
display_name="Vidu Q3 Start/End Frame-to-Video Generation",
|
||||||
|
category="api node/video/Vidu",
|
||||||
|
description="Generate a video from a start frame, an end frame, and a prompt.",
|
||||||
|
inputs=[
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"viduq3-pro",
|
||||||
|
[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=["720p", "1080p"],
|
||||||
|
tooltip="Resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=5,
|
||||||
|
min=1,
|
||||||
|
max=16,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Duration of the output video in seconds.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"audio",
|
||||||
|
default=False,
|
||||||
|
tooltip="When enabled, outputs video with sound "
|
||||||
|
"(including dialogue and sound effects).",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"viduq3-turbo",
|
||||||
|
[
|
||||||
|
IO.Combo.Input(
|
||||||
|
"resolution",
|
||||||
|
options=["720p", "1080p"],
|
||||||
|
tooltip="Resolution of the output video.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"duration",
|
||||||
|
default=5,
|
||||||
|
min=1,
|
||||||
|
max=16,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Duration of the output video in seconds.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"audio",
|
||||||
|
default=False,
|
||||||
|
tooltip="When enabled, outputs video with sound "
|
||||||
|
"(including dialogue and sound effects).",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="Model to use for video generation.",
|
||||||
|
),
|
||||||
|
IO.Image.Input("first_frame"),
|
||||||
|
IO.Image.Input("end_frame"),
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
tooltip="Prompt description (max 2000 characters).",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=1,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
step=1,
|
||||||
|
display_mode=IO.NumberDisplay.number,
|
||||||
|
control_after_generate=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.Video.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
depends_on=IO.PriceBadgeDepends(widgets=["model", "model.duration", "model.resolution"]),
|
||||||
|
expr="""
|
||||||
|
(
|
||||||
|
$res := $lookup(widgets, "model.resolution");
|
||||||
|
$d := $lookup(widgets, "model.duration");
|
||||||
|
$contains(widgets.model, "turbo")
|
||||||
|
? (
|
||||||
|
$rate := $lookup({"720p": 0.06, "1080p": 0.08}, $res);
|
||||||
|
{"type":"usd","usd": $rate * $d}
|
||||||
|
)
|
||||||
|
: (
|
||||||
|
$rate := $lookup({"720p": 0.15, "1080p": 0.16}, $res);
|
||||||
|
{"type":"usd","usd": $rate * $d}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
model: dict,
|
||||||
|
first_frame: Input.Image,
|
||||||
|
end_frame: Input.Image,
|
||||||
|
prompt: str,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, max_length=2000)
|
||||||
|
validate_images_aspect_ratio_closeness(first_frame, end_frame, min_rel=0.8, max_rel=1.25, strict=False)
|
||||||
|
payload = TaskCreationRequest(
|
||||||
|
model=model["model"],
|
||||||
|
prompt=prompt,
|
||||||
|
duration=model["duration"],
|
||||||
|
seed=seed,
|
||||||
|
resolution=model["resolution"],
|
||||||
|
audio=model["audio"],
|
||||||
|
images=[
|
||||||
|
(await upload_images_to_comfyapi(cls, frame, max_images=1, mime_type="image/png"))[0]
|
||||||
|
for frame in (first_frame, end_frame)
|
||||||
|
],
|
||||||
|
)
|
||||||
|
results = await execute_task(cls, VIDU_START_END_VIDEO, payload)
|
||||||
|
return IO.NodeOutput(await download_url_to_video_output(results[0].url))
|
||||||
|
|
||||||
|
|
||||||
class ViduExtension(ComfyExtension):
|
class ViduExtension(ComfyExtension):
|
||||||
@override
|
@override
|
||||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
@ -1497,6 +1706,7 @@ class ViduExtension(ComfyExtension):
|
|||||||
ViduMultiFrameVideoNode,
|
ViduMultiFrameVideoNode,
|
||||||
Vidu3TextToVideoNode,
|
Vidu3TextToVideoNode,
|
||||||
Vidu3ImageToVideoNode,
|
Vidu3ImageToVideoNode,
|
||||||
|
Vidu3StartEndToVideoNode,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
99
comfy_extras/nodes_nag.py
Normal file
99
comfy_extras/nodes_nag.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import torch
|
||||||
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
|
||||||
|
class NAGuidance(io.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls) -> io.Schema:
|
||||||
|
return io.Schema(
|
||||||
|
node_id="NAGuidance",
|
||||||
|
display_name="Normalized Attention Guidance",
|
||||||
|
description="Applies Normalized Attention Guidance to models, enabling negative prompts on distilled/schnell models.",
|
||||||
|
category="",
|
||||||
|
is_experimental=True,
|
||||||
|
inputs=[
|
||||||
|
io.Model.Input("model", tooltip="The model to apply NAG to."),
|
||||||
|
io.Float.Input("nag_scale", min=0.0, default=5.0, max=50.0, step=0.1, tooltip="The guidance scale factor. Higher values push further from the negative prompt."),
|
||||||
|
io.Float.Input("nag_alpha", min=0.0, default=0.5, max=1.0, step=0.01, tooltip="Blending factor for the normalized attention. 1.0 is full replacement, 0.0 is no effect."),
|
||||||
|
io.Float.Input("nag_tau", min=1.0, default=1.5, max=10.0, step=0.01),
|
||||||
|
# io.Float.Input("start_percent", min=0.0, default=0.0, max=1.0, step=0.01, tooltip="The relative sampling step to begin applying NAG."),
|
||||||
|
# io.Float.Input("end_percent", min=0.0, default=1.0, max=1.0, step=0.01, tooltip="The relative sampling step to stop applying NAG."),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
io.Model.Output(tooltip="The patched model with NAG enabled."),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def execute(cls, model: io.Model.Type, nag_scale: float, nag_alpha: float, nag_tau: float) -> io.NodeOutput:
|
||||||
|
m = model.clone()
|
||||||
|
|
||||||
|
# sigma_start = m.get_model_object("model_sampling").percent_to_sigma(start_percent)
|
||||||
|
# sigma_end = m.get_model_object("model_sampling").percent_to_sigma(end_percent)
|
||||||
|
|
||||||
|
def nag_attention_output_patch(out, extra_options):
|
||||||
|
cond_or_uncond = extra_options.get("cond_or_uncond", None)
|
||||||
|
if cond_or_uncond is None:
|
||||||
|
return out
|
||||||
|
|
||||||
|
if not (1 in cond_or_uncond and 0 in cond_or_uncond):
|
||||||
|
return out
|
||||||
|
|
||||||
|
# sigma = extra_options.get("sigmas", None)
|
||||||
|
# if sigma is not None and len(sigma) > 0:
|
||||||
|
# sigma = sigma[0].item()
|
||||||
|
# if sigma > sigma_start or sigma < sigma_end:
|
||||||
|
# return out
|
||||||
|
|
||||||
|
img_slice = extra_options.get("img_slice", None)
|
||||||
|
|
||||||
|
if img_slice is not None:
|
||||||
|
orig_out = out
|
||||||
|
out = out[:, img_slice[0]:img_slice[1]] # only apply on img part
|
||||||
|
|
||||||
|
batch_size = out.shape[0]
|
||||||
|
half_size = batch_size // len(cond_or_uncond)
|
||||||
|
|
||||||
|
ind_neg = cond_or_uncond.index(1)
|
||||||
|
ind_pos = cond_or_uncond.index(0)
|
||||||
|
z_pos = out[half_size * ind_pos:half_size * (ind_pos + 1)]
|
||||||
|
z_neg = out[half_size * ind_neg:half_size * (ind_neg + 1)]
|
||||||
|
|
||||||
|
guided = z_pos * nag_scale - z_neg * (nag_scale - 1.0)
|
||||||
|
|
||||||
|
eps = 1e-6
|
||||||
|
norm_pos = torch.norm(z_pos, p=1, dim=-1, keepdim=True).clamp_min(eps)
|
||||||
|
norm_guided = torch.norm(guided, p=1, dim=-1, keepdim=True).clamp_min(eps)
|
||||||
|
|
||||||
|
ratio = norm_guided / norm_pos
|
||||||
|
scale_factor = torch.minimum(ratio, torch.full_like(ratio, nag_tau)) / ratio
|
||||||
|
|
||||||
|
guided_normalized = guided * scale_factor
|
||||||
|
|
||||||
|
z_final = guided_normalized * nag_alpha + z_pos * (1.0 - nag_alpha)
|
||||||
|
|
||||||
|
if img_slice is not None:
|
||||||
|
orig_out[half_size * ind_neg:half_size * (ind_neg + 1), img_slice[0]:img_slice[1]] = z_final
|
||||||
|
orig_out[half_size * ind_pos:half_size * (ind_pos + 1), img_slice[0]:img_slice[1]] = z_final
|
||||||
|
return orig_out
|
||||||
|
else:
|
||||||
|
out[half_size * ind_pos:half_size * (ind_pos + 1)] = z_final
|
||||||
|
return out
|
||||||
|
|
||||||
|
m.set_model_attn1_output_patch(nag_attention_output_patch)
|
||||||
|
m.disable_model_cfg1_optimization()
|
||||||
|
|
||||||
|
return io.NodeOutput(m)
|
||||||
|
|
||||||
|
|
||||||
|
class NagExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
NAGuidance,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> NagExtension:
|
||||||
|
return NagExtension()
|
||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.13.0"
|
__version__ = "0.14.1"
|
||||||
|
|||||||
1
nodes.py
1
nodes.py
@ -2437,6 +2437,7 @@ async def init_builtin_extra_nodes():
|
|||||||
"nodes_color.py",
|
"nodes_color.py",
|
||||||
"nodes_toolkit.py",
|
"nodes_toolkit.py",
|
||||||
"nodes_replacements.py",
|
"nodes_replacements.py",
|
||||||
|
"nodes_nag.py",
|
||||||
]
|
]
|
||||||
|
|
||||||
import_failed = []
|
import_failed = []
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.13.0"
|
version = "0.14.1"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.38.14
|
comfyui-frontend-package==1.39.14
|
||||||
comfyui-workflow-templates==0.8.42
|
comfyui-workflow-templates==0.8.43
|
||||||
comfyui-embedded-docs==0.4.1
|
comfyui-embedded-docs==0.4.1
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user