ComfyUI/comfy_extras/nodes_qwen.py

294 lines
13 KiB
Python

import node_helpers
import comfy.utils
import comfy.conds
import math
import torch
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
class TextEncodeQwenImageEdit(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeQwenImageEdit",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image", optional=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, prompt, vae=None, image=None) -> io.NodeOutput:
ref_latent = None
if image is None:
images = []
else:
samples = image.movedim(-1, 1)
total = int(1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
image = s.movedim(1, -1)
images = [image[:, :, :, :3]]
if vae is not None:
ref_latent = vae.encode(image[:, :, :, :3])
tokens = clip.tokenize(prompt, images=images)
conditioning = clip.encode_from_tokens_scheduled(tokens)
if ref_latent is not None:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": [ref_latent]}, append=True)
return io.NodeOutput(conditioning)
class TextEncodeQwenImageEditPlus(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeQwenImageEditPlus",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.String.Input("prompt", multiline=True, dynamic_prompts=True),
io.Vae.Input("vae", optional=True),
io.Image.Input("image1", optional=True),
io.Image.Input("image2", optional=True),
io.Image.Input("image3", optional=True),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, prompt, vae=None, image1=None, image2=None, image3=None) -> io.NodeOutput:
ref_latents = []
images = [image1, image2, image3]
images_vl = []
llama_template = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
image_prompt = ""
for i, image in enumerate(images):
if image is not None:
samples = image.movedim(-1, 1)
total = int(384 * 384)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by)
height = round(samples.shape[2] * scale_by)
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
images_vl.append(s.movedim(1, -1))
if vae is not None:
total = int(1024 * 1024)
scale_by = math.sqrt(total / (samples.shape[3] * samples.shape[2]))
width = round(samples.shape[3] * scale_by / 8.0) * 8
height = round(samples.shape[2] * scale_by / 8.0) * 8
s = comfy.utils.common_upscale(samples, width, height, "area", "disabled")
ref_latents.append(vae.encode(s.movedim(1, -1)[:, :, :, :3]))
image_prompt += "Picture {}: <|vision_start|><|image_pad|><|vision_end|>".format(i + 1)
tokens = clip.tokenize(image_prompt + prompt, images=images_vl, llama_template=llama_template)
conditioning = clip.encode_from_tokens_scheduled(tokens)
if len(ref_latents) > 0:
conditioning = node_helpers.conditioning_set_values(conditioning, {"reference_latents": ref_latents}, append=True)
return io.NodeOutput(conditioning)
class TextEncodeQwenImageEliGen(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="TextEncodeQwenImageEliGen",
category="advanced/conditioning",
inputs=[
io.Clip.Input("clip"),
io.Conditioning.Input("global_conditioning"),
io.Latent.Input("latent"),
io.Image.Input("entity_mask_1", optional=True),
io.String.Input("entity_prompt_1", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("entity_mask_2", optional=True),
io.String.Input("entity_prompt_2", multiline=True, dynamic_prompts=True, default=""),
io.Image.Input("entity_mask_3", optional=True),
io.String.Input("entity_prompt_3", multiline=True, dynamic_prompts=True, default=""),
],
outputs=[
io.Conditioning.Output(),
],
)
@classmethod
def execute(cls, clip, global_conditioning, latent, entity_prompt_1="", entity_mask_1=None,
entity_prompt_2="", entity_mask_2=None, entity_prompt_3="", entity_mask_3=None) -> io.NodeOutput:
# Extract dimensions from latent tensor
# latent["samples"] shape: [batch, channels, latent_h, latent_w]
latent_samples = latent["samples"]
unpadded_latent_height = latent_samples.shape[2] # Unpadded latent space
unpadded_latent_width = latent_samples.shape[3] # Unpadded latent space
# Calculate padded dimensions (same logic as model's pad_to_patch_size with patch_size=2)
# The model pads latents to be multiples of patch_size (2 for Qwen)
patch_size = 2
pad_h = (patch_size - unpadded_latent_height % patch_size) % patch_size
pad_w = (patch_size - unpadded_latent_width % patch_size) % patch_size
latent_height = unpadded_latent_height + pad_h # Padded latent dimensions
latent_width = unpadded_latent_width + pad_w # Padded latent dimensions
height = latent_height * 8 # Convert to pixel space for logging
width = latent_width * 8
if pad_h > 0 or pad_w > 0:
print(f"[EliGen] Latent padding detected: {unpadded_latent_height}x{unpadded_latent_width}{latent_height}x{latent_width}")
print(f"[EliGen] Target generation dimensions: {height}x{width} pixels ({latent_height}x{latent_width} latent)")
# Collect entity prompts and masks
entity_prompts = [entity_prompt_1, entity_prompt_2, entity_prompt_3]
entity_masks_raw = [entity_mask_1, entity_mask_2, entity_mask_3]
# Filter out entities with empty prompts or missing masks
valid_entities = []
for prompt, mask in zip(entity_prompts, entity_masks_raw):
if prompt.strip() and mask is not None:
valid_entities.append((prompt, mask))
# Log warning if some entities were skipped
total_prompts_provided = len([p for p in entity_prompts if p.strip()])
if len(valid_entities) < total_prompts_provided:
print(f"[EliGen] Warning: Only {len(valid_entities)} of {total_prompts_provided} entity prompts have valid masks")
# If no valid entities, return standard conditioning
if len(valid_entities) == 0:
return io.NodeOutput(global_conditioning)
# Encode each entity prompt separately
entity_prompt_emb_list = []
entity_prompt_emb_mask_list = []
for entity_prompt, _ in valid_entities:
entity_tokens = clip.tokenize(entity_prompt)
entity_cond = clip.encode_from_tokens_scheduled(entity_tokens)
# Extract embeddings and masks from conditioning
# Conditioning format: [[cond_tensor, extra_dict], ...]
entity_prompt_emb = entity_cond[0][0] # The embedding tensor directly [1, seq_len, 3584]
extra_dict = entity_cond[0][1] # Metadata dict (pooled_output, attention_mask, etc.)
# Extract attention mask from metadata dict
entity_prompt_emb_mask = extra_dict.get("attention_mask", None)
# If no attention mask in extra_dict, create one (all True)
if entity_prompt_emb_mask is None:
seq_len = entity_prompt_emb.shape[1]
entity_prompt_emb_mask = torch.ones((entity_prompt_emb.shape[0], seq_len),
dtype=torch.bool, device=entity_prompt_emb.device)
entity_prompt_emb_list.append(entity_prompt_emb)
entity_prompt_emb_mask_list.append(entity_prompt_emb_mask)
# Process spatial masks to latent space
processed_masks = []
for i, (_, mask) in enumerate(valid_entities):
# mask is expected to be [batch, height, width, channels] or [batch, height, width]
mask_tensor = mask
# Log original mask dimensions
original_shape = mask_tensor.shape
if len(original_shape) == 3:
orig_h, orig_w = original_shape[0], original_shape[1]
elif len(original_shape) == 4:
orig_h, orig_w = original_shape[1], original_shape[2]
else:
orig_h, orig_w = original_shape[-2], original_shape[-1]
print(f"[EliGen] Entity {i+1} mask: {orig_h}x{orig_w} → will resize to {latent_height}x{latent_width} latent")
# Ensure mask is in [batch, channels, height, width] format for upscale
if len(mask_tensor.shape) == 3:
# [height, width, channels] -> [1, height, width, channels] (add batch dimension)
mask_tensor = mask_tensor.unsqueeze(0)
elif len(mask_tensor.shape) == 4 and mask_tensor.shape[-1] in [1, 3, 4]:
# [batch, height, width, channels] -> [batch, channels, height, width]
mask_tensor = mask_tensor.movedim(-1, 1)
# Take only first channel if multiple channels
if mask_tensor.shape[1] > 1:
mask_tensor = mask_tensor[:, 0:1, :, :]
# Resize to latent space dimensions using nearest neighbor
resized_mask = comfy.utils.common_upscale(
mask_tensor,
latent_width,
latent_height,
upscale_method="nearest-exact",
crop="disabled"
)
# Threshold to binary (0 or 1)
# Use > 0 instead of > 0.5 to preserve edge pixels from nearest-neighbor downsampling
resized_mask = (resized_mask > 0).float()
# Log how many pixels are active in the mask
active_pixels = (resized_mask > 0).sum().item()
total_pixels = resized_mask.numel()
print(f"[EliGen] Entity {i+1} mask coverage: {active_pixels}/{total_pixels} pixels ({100*active_pixels/total_pixels:.1f}%)")
processed_masks.append(resized_mask)
# Stack masks: [batch, num_entities, 1, latent_height, latent_width]
# No padding - handle dynamic number of entities
entity_masks_tensor = torch.stack(processed_masks, dim=1)
# Extract global prompt embedding and mask from conditioning
# Conditioning format: [[cond_tensor, extra_dict]]
global_prompt_emb = global_conditioning[0][0] # The embedding tensor directly
global_extra_dict = global_conditioning[0][1] # Metadata dict
global_prompt_emb_mask = global_extra_dict.get("attention_mask", None)
# If no attention mask, create one (all True)
if global_prompt_emb_mask is None:
global_prompt_emb_mask = torch.ones((global_prompt_emb.shape[0], global_prompt_emb.shape[1]),
dtype=torch.bool, device=global_prompt_emb.device)
# Attach entity data to conditioning using conditioning_set_values
# Wrap lists in CONDList so they can be properly concatenated during CFG
entity_data = {
"entity_prompt_emb": comfy.conds.CONDList(entity_prompt_emb_list),
"entity_prompt_emb_mask": comfy.conds.CONDList(entity_prompt_emb_mask_list),
"entity_masks": comfy.conds.CONDRegular(entity_masks_tensor),
}
conditioning_with_entities = node_helpers.conditioning_set_values(
global_conditioning,
entity_data,
append=True
)
return io.NodeOutput(conditioning_with_entities)
class QwenExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [
TextEncodeQwenImageEdit,
TextEncodeQwenImageEditPlus,
TextEncodeQwenImageEliGen,
]
async def comfy_entrypoint() -> QwenExtension:
return QwenExtension()