# generate_text.py
import os
import sys
import json
import base64
import asyncio
import logging
from typing import Dict, Any, List, Optional, Union, Tuple, AsyncGenerator
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
# Add parent directory to path to ensure imports work correctly
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
sys.path.insert(0, parent_dir)
# Try to import ComfyUI-specific modules
try:
import folder_paths
from server import PromptServer # <-- Import PromptServer
except ImportError:
logger.warning("Could not import folder_paths or server. Make sure ComfyUI environment is set up.")
folder_paths = None
PromptServer = None
# Check for required dependencies
missing_deps = []
try:
import aiohttp
except ImportError:
missing_deps.append("aiohttp")
if missing_deps:
logger.warning(f"Missing dependencies: {', '.join(missing_deps)}. Some functionality may not work.")
logger.warning("Please install missing dependencies: pip install " + " ".join(missing_deps))
# Import utility functions (assuming they exist)
from send_request import send_request, run_async, is_gpt5_model # Keep non-streaming version if needed
from api.openai_api import send_openai_responses_stream
from api.gemini_api import send_gemini_request_stream
from llmtoolkit_utils import query_local_ollama_models, ensure_ollama_server, ensure_ollama_model, get_api_key
# Local transformers streaming (optional) - removed unused import
# The send_transformers_request_stream function was imported but never used
send_transformers_request_stream = None # type: ignore
# Payload helper to embed context into a string subclass
from context_payload import ContextPayload
# -----------------------------------------------------------------------------
# Helpers: Lazy OpenCV import and video -> base64 frame extraction
# -----------------------------------------------------------------------------
_cv2_ref = None
# Set once an import attempt has failed, so later calls do not retry the
# import (and do not repeat the warning) on every video path.
_cv2_import_failed = False


def _get_cv2():
    """Return the lazily imported ``cv2`` module, or ``None`` if unavailable.

    The import is attempted at most once per process: a successful import is
    cached in ``_cv2_ref`` and a failed one is remembered in
    ``_cv2_import_failed``. The warning is therefore logged only once.
    """
    global _cv2_ref, _cv2_import_failed
    if _cv2_ref is None and not _cv2_import_failed:
        try:
            import cv2 as _cv2  # type: ignore
        except Exception:
            # Broad catch kept from the original: any failure while importing
            # cv2 simply disables frame extraction.
            _cv2_import_failed = True
            logger.warning("cv2 not available. Video file frame extraction disabled.")
        else:
            _cv2_ref = _cv2
    return _cv2_ref
def _is_video_file(path: str) -> bool:
    """Tell whether *path* ends with one of the recognised video extensions.

    The comparison is case-insensitive. Any error while splitting the path
    (for example a non-path argument) is treated as "not a video".
    """
    video_suffixes = (".mp4", ".mov", ".mkv", ".avi", ".webm", ".m4v")
    try:
        _, suffix = os.path.splitext(path)
        suffix = suffix.lower()
    except Exception:
        return False
    return suffix in video_suffixes
def _extract_video_file_frames_as_b64(
    video_path: str,
    max_frames: int = 5,
    stride: int = 16,
) -> List[str]:
    """Extract up to `max_frames` JPEG base64 frames from a local video file.

    - Samples frames 0, stride, 2*stride, ... until `max_frames` frames have
      been encoded or a read fails.
    - If stride sampling produced no frame at all, falls back to reading the
      first frames sequentially from the start of the file.
    - A `stride` below 1 is treated as 1 so sampling always moves forward.
    - The capture handle is always released, including on errors.
    - Returns an empty list on any error or if cv2 is unavailable.

    Args:
        video_path: Path of the video file to open.
        max_frames: Maximum number of frames to return.
        stride: Distance, in frames, between sampled positions.

    Returns:
        A list of base64-encoded (ASCII) JPEG images, possibly empty.
    """
    cv2 = _get_cv2()
    if cv2 is None:
        return []

    def _encode_jpeg_b64(frame) -> Optional[str]:
        """Encode one frame as base64 JPEG; ``None`` if encoding fails."""
        success, buf = cv2.imencode(".jpg", frame)
        if not success:
            return None
        return base64.b64encode(buf.tobytes()).decode("ascii")

    step = max(1, stride)
    try:
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                logger.warning("Could not open video file: %s", video_path)
                return []

            frames: List[str] = []

            # Try stride sampling first
            frame_index = 0
            while len(frames) < max_frames:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
                ok, frame = cap.read()
                if not ok:
                    break
                encoded = _encode_jpeg_b64(frame)
                if encoded is not None:
                    frames.append(encoded)
                frame_index += step

            # If we got nothing with stride, try the first sequential frames
            if not frames:
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
                while len(frames) < max_frames:
                    ok, frame = cap.read()
                    if not ok:
                        break
                    encoded = _encode_jpeg_b64(frame)
                    if encoded is not None:
                        frames.append(encoded)
        finally:
            # Release on every path (early return, normal exit, exception).
            cap.release()

        if frames:
            logger.debug("Extracted %d frame(s) from video file %s", len(frames), video_path)
        return frames
    except Exception as e:
        logger.warning("Error extracting frames from %s: %s", video_path, e, exc_info=True)
        return []
# --- STREAMING REQUEST FUNCTION (OpenAI/OpenRouter, Gemini, Ollama, Groq) ---
# IMPORTANT: Each provider branch must follow that provider's actual streaming API structure.
def _build_openai_style_messages(
    system_message: str,
    user_message: str,
    images: Optional[List[str]],
    always_include_text: bool = False,
) -> List[Dict[str, Any]]:
    """Build an OpenAI-style chat ``messages`` list from plain prompt parts.

    With images, the user turn becomes a list of content blocks (one text
    block followed by one ``image_url`` block per image). The text block is
    skipped for an empty ``user_message`` unless ``always_include_text`` is set.
    """
    built: List[Dict[str, Any]] = []
    if system_message:
        built.append({"role": "system", "content": system_message})
    if images:
        blocks: List[Dict[str, Any]] = []
        if user_message or always_include_text:
            blocks.append({"type": "text", "text": user_message})
        for img_b64 in images:
            blocks.append({
                "type": "image_url",
                "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
            })
        built.append({"role": "user", "content": blocks})
    elif user_message:
        built.append({"role": "user", "content": user_message})
    return built


def _parse_sse_chat_chunk(raw_line: bytes, provider_label: str) -> Tuple[bool, Optional[str]]:
    """Parse one line of a Chat Completions event stream.

    Returns ``(done, text)``: ``done`` is True for the ``[DONE]`` sentinel and
    ``text`` is the delta content carried by the line, or ``None``.
    """
    if not raw_line:
        return False, None
    line = raw_line.decode("utf-8").strip()
    # Lines starting with ':' are SSE comments (keep-alives), not payload.
    if not line or line.startswith(":"):
        return False, None
    # Accept both 'data: {...}' and 'data:{...}'.
    data_str = line[len("data:"):].strip() if line.startswith("data:") else line
    if not data_str:
        return False, None
    if data_str == "[DONE]":
        return True, None
    try:
        data_json = json.loads(data_str)
    except json.JSONDecodeError:
        logger.warning(f"Could not decode JSON line from {provider_label} stream: {data_str}")
        return False, None
    if not isinstance(data_json, dict):
        return False, None
    choices = data_json.get("choices", [])
    if not choices:
        return False, None
    # 'delta' may be missing or null on some chunks.
    delta = choices[0].get("delta") or {}
    return False, (delta.get("content") or None)


def _new_stream_session(timeout: int):
    """Create the aiohttp session used for streaming (force-close connector on Windows)."""
    connector = aiohttp.TCPConnector(force_close=True) if sys.platform == 'win32' else None
    return aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(total=timeout),
        connector=connector
    )


async def _close_stream_session(session) -> None:
    """Close a streaming session if it is still open, then pause briefly on Windows."""
    if session and not session.closed:
        await session.close()
        # Small delay to allow cleanup on Windows
        if sys.platform == 'win32':
            await asyncio.sleep(0.1)


async def send_request_stream(
    llm_provider: str,
    base_ip: str,
    port: str,
    llm_model: str,
    system_message: str,
    user_message: str,
    messages: List[Dict[str, str]],
    seed: Optional[int] = None,
    temperature: float = 0.7,
    max_tokens: int = 1024,
    random: bool = False,
    top_k: int = 40,
    top_p: float = 0.9,
    repeat_penalty: float = 1.1,
    stop: Optional[List[str]] = None,
    keep_alive: Union[bool, str] = True,
    llm_api_key: Optional[str] = None,
    timeout: int = 120,  # Add a timeout for the connection
    base64_images: Optional[List[str]] = None,
) -> AsyncGenerator[str, None]:
    """
    Sends a streaming request to an LLM provider and yields text chunks as
    they are received.

    Dispatch by ``llm_provider`` (case-insensitive):
      * ``openai`` / ``openrouter``: Chat Completions event stream; GPT-5
        models on ``openai`` first try the Responses API stream.
      * ``gemini``: delegated to ``send_gemini_request_stream``.
      * ``ollama``: ``/api/chat`` when images are supplied, else ``/api/generate``.
      * ``groq``: OpenAI-compatible Chat Completions event stream.
      * ``transformers`` / ``hf`` / ``local``: no handler here; nothing is yielded.
      * anything else: one non-streaming ``send_request`` call whose text is
        yielded as a single chunk.

    Errors are reported in-band as bracketed strings (for example
    ``"[Timeout Error]"``) rather than raised. The caller's ``messages`` list
    is never modified.

    Args:
        llm_provider: Provider name selecting the branch above.
        base_ip, port: Host and port, used for Ollama and the fallback call.
        llm_model: Model identifier sent to the provider.
        system_message, user_message: Used to build the conversation when
            ``messages`` is empty (Ollama also sends them directly).
        messages: Pre-built chat history; used as-is when non-empty.
        seed, repeat_penalty, stop: Used by Ollama options and the fallback call.
        temperature, max_tokens, top_p: Sampling parameters sent to every provider.
        random, keep_alive: Only forwarded to the non-streaming fallback.
        top_k: Used by Gemini, Ollama and the fallback call.
        llm_api_key: Required for OpenAI, OpenRouter, Gemini and Groq.
        timeout: Total aiohttp timeout in seconds for the HTTP streams opened here.
        base64_images: Optional base64 images for multimodal requests.

    Yields:
        Text chunks, or a bracketed error string.
    """
    provider_lower = llm_provider.lower()

    if provider_lower in ["openai", "openrouter"]:
        # --- OpenAI & OpenRouter Specific Streaming Logic ---
        if not llm_api_key:
            logger.error(f"{llm_provider} streaming requested but no API key supplied.")
            yield f"[{llm_provider} Error: API key missing]"
            return

        # Check if this is a GPT-5 model and use Responses API (OpenAI only)
        if provider_lower == "openai" and is_gpt5_model(llm_model):
            logger.info(f"Detected GPT-5 model: {llm_model}, using Responses API for streaming")
            emitted_any = False
            try:
                async for chunk in send_openai_responses_stream(
                    api_url="https://api.openai.com/v1/responses",
                    base64_images=base64_images,
                    model=llm_model,
                    system_message=system_message,
                    user_message=user_message,
                    messages=messages,
                    api_key=llm_api_key,
                    temperature=temperature,
                    max_tokens=max_tokens,
                    top_p=top_p,
                ):
                    emitted_any = True
                    yield chunk
                return
            except Exception as e:
                if emitted_any:
                    # Part of the answer was already delivered; restarting via
                    # Chat Completions would repeat that text downstream.
                    logger.error(f"GPT-5 Responses stream failed mid-stream: {e}", exc_info=True)
                    yield f"[OpenAI streaming error: {e}]"
                    return
                logger.warning(f"GPT-5 Responses stream failed: {e}, falling back to Chat Completions")
                # Fall through to regular OpenAI streaming below

        api_url = (
            "https://openrouter.ai/api/v1/chat/completions"
            if provider_lower == "openrouter"
            else "https://api.openai.com/v1/chat/completions"
        )
        headers = {
            "Authorization": f"Bearer {llm_api_key}",
            "Content-Type": "application/json",
        }

        # Build message list if not provided
        if not messages:
            messages = _build_openai_style_messages(system_message, user_message, base64_images)

        payload = {
            "model": llm_model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
            "top_p": top_p,
            "stream": True,
        }
        # Remove None values
        payload = {k: v for k, v in payload.items() if v is not None}

        logger.info(f"Streaming request to {llm_provider}: model={llm_model}")

        session = None
        try:
            session = _new_stream_session(timeout)
            async with session.post(api_url, headers=headers, json=payload) as response:
                response.raise_for_status()
                async for raw_line in response.content:
                    done, piece = _parse_sse_chat_chunk(raw_line, "OpenAI")
                    if done:
                        break
                    if piece:
                        yield piece
        except Exception as e:
            logger.error(f"Error during OpenAI streaming: {e}", exc_info=True)
            yield f"[OpenAI streaming error: {e}]"
        finally:
            # Ensure session is properly closed
            await _close_stream_session(session)
        return

    if provider_lower == "gemini":
        if not llm_api_key:
            logger.error("Gemini streaming requested but no API key supplied.")
            yield "[Gemini Error: API key missing]"
            return
        async for chunk in send_gemini_request_stream(
            api_key=llm_api_key,
            model=llm_model,
            system_message=system_message,
            user_message=user_message,
            messages=messages,
            base64_images=base64_images,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            top_k=top_k,
        ):
            yield chunk
        return

    if provider_lower == "ollama":
        # --- Ollama Specific Streaming Logic ---
        if not ensure_ollama_server(base_ip, port):
            logger.error("Ollama daemon unavailable and could not be started – aborting stream.")
            yield "[Error: Ollama daemon unavailable]"
            return

        # Ensure requested model is present locally (will pull if missing)
        ensure_ollama_model(llm_model, base_ip, port)

        # Ollama expects plain base64 strings without the 'data:image/...' prefix
        if base64_images:
            base64_images = [
                img.split("base64,")[1] if isinstance(img, str) and "base64," in img else img
                for img in base64_images
            ]

        # Decide which Ollama endpoint to use:
        #   • /api/generate – fast text-only streaming (no image support)
        #   • /api/chat     – full chat/completions with image support
        # We switch to /api/chat automatically if the caller supplied base64-encoded
        # images so that vision models (e.g. qwen-vl, llava) receive the image bytes.
        use_chat_endpoint = bool(base64_images)
        endpoint = "chat" if use_chat_endpoint else "generate"
        url = f"http://{base_ip}:{port}/api/{endpoint}"

        headers = {"Content-Type": "application/json"}
        if llm_api_key:  # Ollama doesn't typically use API keys this way, but include for consistency
            headers["Authorization"] = f"Bearer {llm_api_key}"

        # Options shared by both endpoints; None values are dropped because
        # Ollama might not like them.
        options = {
            "seed": seed,
            "temperature": temperature,
            "num_predict": max_tokens,
            "top_k": top_k,
            "top_p": top_p,
            "repeat_penalty": repeat_penalty,
            "stop": stop,
        }
        options = {k: v for k, v in options.items() if v is not None}

        if use_chat_endpoint:
            # /api/chat (supports multimodal, messages array).
            # Work on a copy so the caller's list is not extended.
            chat_messages: List[Dict[str, Any]] = list(messages) if messages else []
            if not messages and system_message:
                chat_messages.append({"role": "system", "content": system_message})
            # Always append the user message at the end so that vision models
            # get the most recent prompt + images in a single message.
            chat_messages.append({
                "role": "user",
                "content": user_message or "",
                "images": base64_images,  # Ollama expects a list key called "images"
            })
            payload = {
                "model": llm_model,
                "messages": chat_messages,
                "stream": True,
                "options": options,
            }
        else:
            # /api/generate (text-only): prompt and system are sent directly.
            payload = {
                "model": llm_model,
                "prompt": user_message,
                "stream": True,
                "options": options,
            }
            if system_message:
                payload["system"] = system_message

        logger.info(f"Streaming request to Ollama: {url} with payload: {{'model': '{llm_model}', 'stream': True, 'use_chat': {use_chat_endpoint}}}")

        try:
            # Use a single session if possible, manage timeouts
            async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
                async with session.post(url, headers=headers, json=payload) as response:
                    if response.status >= 400:
                        # Read the body while the response is still open; the
                        # exception raised by raise_for_status() does not carry it.
                        error_body = await response.text()
                        logger.error(f"HTTP error {response.status} from {url}: {response.reason}")
                        logger.error(f"Error Body: {error_body}")
                        yield f"[HTTP Error {response.status}: {response.reason} - {error_body[:100]}]"
                        return

                    # Process the streaming response line by line. The JSON schema differs
                    # slightly between /api/generate and /api/chat, so we branch inside the loop.
                    async for line in response.content:
                        if not line:
                            continue
                        try:
                            decoded_line = line.decode("utf-8").strip()
                            if not decoded_line:
                                continue
                            data = json.loads(decoded_line)
                            if use_chat_endpoint:
                                # Chat endpoint returns {"message": {"content": "..."}, "done": bool}
                                chunk = (data.get("message") or {}).get("content", "")
                            else:
                                # Generate endpoint returns {"response": "...", "done": bool}
                                chunk = data.get("response", "")
                            is_done = data.get("done", False)

                            if chunk:
                                yield chunk
                            if is_done:
                                logger.info("Ollama stream finished.")
                                break
                        except json.JSONDecodeError:
                            logger.warning(
                                f"Could not decode JSON line from Ollama stream: {line.decode('utf-8', errors='ignore')}"
                            )
                        except Exception as e:
                            logger.error(f"Error processing Ollama stream line: {e}", exc_info=True)
                            yield f"[Error processing stream: {e}]"
                            break  # Stop streaming on error
        except aiohttp.ClientConnectorError as e:
            logger.error(f"Connection error to {url}: {e}")
            yield f"[Connection Error: {e}]"
        except aiohttp.ClientResponseError as e:
            logger.error(f"HTTP error {e.status} from {url}: {e.message}")
            yield f"[HTTP Error {e.status}: {e.message}]"
        except asyncio.TimeoutError:
            logger.error(f"Request timed out after {timeout} seconds to {url}")
            yield f"[Timeout Error]"
        except ConnectionResetError as e:
            logger.warning(f"Connection reset by peer: {e}")
            yield f"[Connection Reset: The server closed the connection]"
        except OSError as e:
            if e.errno == 10054:  # Windows specific: connection forcibly closed
                logger.warning("Connection forcibly closed by remote host")
                yield f"[Connection closed by server]"
            else:
                logger.error(f"OS Error during streaming: {e}")
                yield f"[Network Error: {e}]"
        except Exception as e:
            logger.error(f"An unexpected error occurred during streaming request: {e}", exc_info=True)
            yield f"[Unexpected Error: {e}]"
        return

    if provider_lower == "groq":
        # --- Groq Specific Streaming Logic ---
        if not llm_api_key:
            logger.error("Groq streaming requested but no API key supplied.")
            # Fallback could be added here if desired, for now just yield error
            yield "[Groq Error: API key missing]"
            return

        groq_url = "https://api.groq.com/openai/v1/chat/completions"
        headers = {
            "Authorization": f"Bearer {llm_api_key}",
            "Content-Type": "application/json",
        }

        # Handle multimodal user content (text + images)
        is_vision_model = "scout" in llm_model or "maverick" in llm_model
        images_to_send = base64_images
        if images_to_send and not is_vision_model:
            logger.warning("Groq stream: Model '%s' may not support images. Sending without.", llm_model)
            images_to_send = None
        elif images_to_send and len(images_to_send) > 5:
            logger.warning("Groq stream: Taking first 5 of %s images for vision model.", len(images_to_send))
            images_to_send = images_to_send[:5]

        # Build message list if not provided. The text block is always
        # included for image requests, even when the prompt is empty.
        if not messages:
            messages = _build_openai_style_messages(
                system_message, user_message, images_to_send, always_include_text=True
            )

        payload = {
            "model": llm_model,
            "messages": messages,
            "temperature": temperature,
            "max_completion_tokens": max_tokens,  # Use correct Groq param
            "top_p": top_p,
            "stream": True,
        }
        # Remove None values
        payload = {k: v for k, v in payload.items() if v is not None}

        logger.info(f"Streaming request to Groq: model={llm_model}")

        session = None
        try:
            session = _new_stream_session(timeout)
            async with session.post(groq_url, headers=headers, json=payload) as response:
                response.raise_for_status()
                async for raw_line in response.content:
                    done, piece = _parse_sse_chat_chunk(raw_line, "Groq")
                    if done:
                        break
                    if piece:
                        yield piece
        except Exception as e:
            logger.error(f"Error during Groq streaming: {e}", exc_info=True)
            yield f"[Groq streaming error: {e}]"
        finally:
            # Ensure session is properly closed
            await _close_stream_session(session)
        return

    # Existing fallback logic for other providers
    if provider_lower not in ["ollama", "openai", "openrouter", "transformers", "hf", "local", "groq", "gemini"]:
        logger.warning(f"Streaming not implemented for provider '{llm_provider}'. Falling back to non-streaming.")
        try:
            full_response_data = await send_request(
                llm_provider=llm_provider, base_ip=base_ip, port=port, images=base64_images, llm_model=llm_model,
                system_message=system_message, user_message=user_message, messages=messages, seed=seed,
                temperature=temperature, max_tokens=max_tokens, random=random, top_k=top_k, top_p=top_p,
                repeat_penalty=repeat_penalty, stop=stop, keep_alive=keep_alive, llm_api_key=llm_api_key
            )
            if isinstance(full_response_data, dict):
                if "choices" in full_response_data and full_response_data["choices"]:
                    message = full_response_data["choices"][0].get("message", {})
                    content = message.get("content", "")
                    if content:
                        yield content
                elif "response" in full_response_data:
                    if full_response_data["response"]:
                        yield full_response_data["response"]
                elif "candidates" in full_response_data and full_response_data.get("candidates"):
                    try:
                        content = full_response_data["candidates"][0]["content"]["parts"][0]["text"]
                        if content:
                            yield content
                    except (KeyError, IndexError, TypeError):
                        logger.warning(f"Could not parse Gemini response format: {full_response_data}")
                else:
                    logger.error(f"Unexpected non-streaming response format: {full_response_data}")
            elif isinstance(full_response_data, str):
                if full_response_data:
                    yield full_response_data
            else:
                logger.error(f"Unexpected non-streaming response type: {type(full_response_data)}")
        except Exception as e:
            logger.error(f"Error in fallback non-streaming request for {llm_provider}: {e}", exc_info=True)
            yield f"[Error: {e}]"
        return  # Stop generation after yielding the fallback response
def _remove_thinking_tags(text: str) -> str:
    """Remove model "thinking" blocks from text, including the tags themselves.

    Two delimiter styles are stripped: the angle-bracket ``think`` open/close
    tag pair (whitespace inside the brackets is tolerated) and the
    ``◁think▷ ... ◁/think▷`` pair. Blocks may span several lines. Runs of
    blank lines left behind are collapsed and the result is trimmed.

    Args:
        text: Raw model output.

    Returns:
        The text without thinking blocks.
    """
    import re
    # Each alternative needs explicit opening and closing delimiters: a bare
    # lazy `.*?` alternative matches the empty string everywhere and therefore
    # removes nothing.
    pattern = r'<\s*think\s*>.*?<\s*/\s*think\s*>|◁think▷.*?◁/think▷'
    cleaned = re.sub(pattern, '', text, flags=re.DOTALL)
    # Clean up any extra whitespace/newlines left behind
    cleaned = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned)  # Replace multiple newlines with double
    return cleaned.strip()
# --- Original Node (for reference or non-streaming use) ---
class LLMToolkitTextGenerator:
    """Non-streaming text-generation node.

    ``generate`` merges node inputs with an optional upstream ``context``
    (provider configuration and prompt configuration), performs a single
    blocking ``send_request`` call and returns the text together with an
    updated context payload.
    """

    DEFAULT_PROVIDER = "openai"
    DEFAULT_MODEL: str = "gpt-4o-mini"
    MODEL_LIST: List[str] = [DEFAULT_MODEL]
    # Upper bound on images per request once video-file frames are added,
    # to avoid huge payloads.
    MAX_TOTAL_IMAGES = 8

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node inputs (required, optional and hidden)."""
        return {
            "required": {
                "prompt": ("STRING", {"multiline": False, "default": "Write a short story about a robot learning to paint."}),
                "hide_thinking": ("BOOLEAN", {"default": True, "tooltip": "Hide model thinking process (content between tags)"})
            },
            "optional": {
                "context": ("*", {}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.05}),
                "max_tokens": ("INT", {"default": 1024, "min": 1, "max": 65536}),
                "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.05}),
                "top_k": ("INT", {"default": 40, "min": 0, "max": 500}),
                "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647, "tooltip": "-1 = random seed"})
            },
            "hidden": {
                "llm_model": ("STRING", {"default": cls.DEFAULT_MODEL})
            }
        }

    RETURN_TYPES = ("*", "STRING")
    RETURN_NAMES = ("context", "text")
    FUNCTION = "generate"
    CATEGORY = "🔗llm_toolkit/generators"
    OUTPUT_NODE = True  # Keeps the text widget for non-streaming version

    @staticmethod
    def _apply_provider_config(params: Dict[str, Any], context: Any) -> None:
        """Overlay provider settings found in *context* onto *params* (in place)."""
        provider_config = None
        if isinstance(context, dict):
            if "provider_name" in context:
                # Flat context: the context dict itself is the provider config.
                provider_config = context
            elif "provider_config" in context:
                provider_config = context["provider_config"]
        if not (provider_config and isinstance(provider_config, dict)):
            return

        # Keys whose name differs between provider_config and params.
        for source_key, target_key in (
            ("provider_name", "llm_provider"),
            ("api_key", "llm_api_key"),
            ("base_ip", "base_ip"),
            ("port", "port"),
        ):
            if source_key in provider_config:
                params[target_key] = provider_config[source_key]

        # Every other key is copied through under its own name.
        handled = {"provider_name", "llm_model", "api_key", "base_ip", "port", "user_prompt"}
        for key in provider_config:
            if key not in handled:
                params[key] = provider_config[key]

        if "user_prompt" in provider_config:
            params["user_message"] = provider_config["user_prompt"]

        # An absent/empty model is reset to "" so the per-provider default in
        # _apply_provider_defaults can fill it in.
        params["llm_model"] = provider_config.get("llm_model", "") or ""

    @staticmethod
    def _apply_provider_defaults(params: Dict[str, Any], hide_thinking: bool) -> None:
        """Fill provider-specific defaults: model fallback, Groq reasoning format, OpenAI key."""
        if not params.get("llm_provider"):
            return
        provider = str(params["llm_provider"]).lower()

        if not params.get("llm_model"):
            provider_defaults = {"openai": "gpt-4o-mini", "anthropic": "claude-3-opus-20240229"}
            fallback = provider_defaults.get(provider)
            if fallback:
                params["llm_model"] = fallback

        # --- Groq reasoning-format handling ---
        if provider == "groq":
            # Hide thinking via API when requested to avoid post-processing.
            params["reasoning_format"] = "hidden" if hide_thinking else "raw"

        # Auto-fetch API key for OpenAI if missing/placeholder
        current_key = params.get("llm_api_key")
        if provider == "openai" and (not current_key or current_key == "1234"):
            try:
                params["llm_api_key"] = get_api_key("OPENAI_API_KEY", "openai")
                logger.info("generate: Retrieved OpenAI API key via get_api_key helper.")
            except ValueError as _e:
                logger.warning(f"generate: get_api_key failed – {_e}")

    @staticmethod
    def _apply_prompt_config(params: Dict[str, Any], context: Any) -> None:
        """Apply an upstream ``prompt_config`` dict (text, images, files, urls, audio)."""
        prompt_cfg = context.get("prompt_config") if isinstance(context, dict) else None
        if not (prompt_cfg and isinstance(prompt_cfg, dict)):
            params["images"] = None
            return

        # Override user prompt text if provided
        if prompt_cfg.get("text"):
            params["user_message"] = prompt_cfg["text"]

        # Collect images and video frames (already in base64 from PromptManager)
        imgs: List[Any] = []
        img_val = prompt_cfg.get("image_base64")
        if img_val:
            if isinstance(img_val, str):
                imgs.append(img_val)
            elif isinstance(img_val, list):
                imgs.extend(img_val)
        vid_frames = prompt_cfg.get("video_frames_base64")
        if vid_frames:
            if isinstance(vid_frames, list):
                imgs.extend(vid_frames)
            else:
                imgs.append(vid_frames)
        params["images"] = imgs if imgs else None

        # Pass through file paths and URLs for APIs that support them
        if prompt_cfg.get("file_paths"):
            params["file_paths"] = prompt_cfg["file_paths"]
            logger.debug(f"PromptManager: Passing file_paths to API (some APIs process videos/PDFs directly).")
        if prompt_cfg.get("urls"):
            params["urls"] = prompt_cfg["urls"]
            logger.debug(f"PromptManager: Passing URLs to API.")
        # Audio path remains separate
        if prompt_cfg.get("audio_path"):
            params["audio_path"] = prompt_cfg["audio_path"]

    @classmethod
    def _attach_video_file_frames(cls, params: Dict[str, Any]) -> None:
        """Extract frames from local video files in ``file_paths`` and add them to ``images``.

        Best-effort: any failure is logged and leaves *params* usable.
        """
        try:
            file_paths_value = params.get("file_paths")
            candidate_paths: List[str] = []
            if isinstance(file_paths_value, str) and file_paths_value.strip():
                candidate_paths = [file_paths_value.strip()]
            elif isinstance(file_paths_value, list):
                candidate_paths = [p for p in file_paths_value if isinstance(p, str) and p.strip()]

            extracted_from_files: List[str] = []
            for p in candidate_paths:
                if _is_video_file(p) and os.path.isfile(p):
                    extracted_from_files.extend(_extract_video_file_frames_as_b64(p))
            if not extracted_from_files:
                return

            if params.get("images") is None:
                params["images"] = []
            if not isinstance(params["images"], list):
                params["images"] = [params["images"]]  # normalize
            # cap total images to avoid huge payloads
            remaining = max(0, cls.MAX_TOTAL_IMAGES - len(params["images"]))
            params["images"].extend(extracted_from_files[:remaining])
            logger.info(
                "Added %d frame(s) extracted from video file paths to images payload.",
                min(len(extracted_from_files), remaining),
            )
        except Exception as e:
            logger.warning("Failed to process video file paths for frame extraction: %s", e, exc_info=True)

    @staticmethod
    def _extract_content(response_data: Any) -> str:
        """Pull the generated text out of a ``send_request`` result.

        Handles plain strings, OpenAI-style ``choices``, Ollama-style
        ``response`` and Gemini-style ``candidates`` dicts. Anything else is
        reported as an ``"Error: ..."`` string. The result is always a ``str``.
        """
        if response_data is None:
            return "Error: Received None response"
        if isinstance(response_data, str):
            return response_data
        if not isinstance(response_data, dict):
            return f"Error: Unexpected type: {type(response_data)}"

        if response_data.get("choices"):
            # "message" may be present but null.
            message = response_data["choices"][0].get("message") or {}
            content = message.get("content", "")
        elif "response" in response_data:
            content = response_data["response"]
        elif response_data.get("candidates"):
            try:
                content = response_data["candidates"][0]["content"]["parts"][0]["text"]
            except (KeyError, IndexError, TypeError):
                return f"Error: Unexpected format: {str(response_data)}"
        else:
            return f"Error: Unexpected format: {str(response_data)}"

        if content is None:
            return "Error: Null content in response"
        # Later steps (tag stripping, ContextPayload) work on text.
        return content if isinstance(content, str) else str(content)

    def generate(self, prompt, hide_thinking, context=None, temperature=0.7, max_tokens=1024, top_p=0.9, top_k=40, seed=-1, llm_model=None):
        """Run one non-streaming LLM request and return the text plus context.

        Args:
            prompt: User prompt; may be overridden by the context's
                ``user_prompt`` or ``prompt_config["text"]``.
            hide_thinking: When true, thinking blocks are stripped from the
                result (and Groq is asked to hide them).
            context: Optional upstream dict carrying ``provider_config`` /
                ``prompt_config``; any other value is passed through as
                ``passthrough_data``.
            temperature, max_tokens, top_p, top_k: Sampling parameters.
            seed: ``-1`` means no fixed seed.
            llm_model: Model name; defaults to ``DEFAULT_MODEL``.

        Returns:
            ``{"ui": {"string": [text]}, "result": (payload, str(payload))}``.
            Failures are returned in the same shape with an error message as
            the text; no exception is raised.
        """
        try:
            # Base parameter defaults
            params = {
                "llm_provider": self.DEFAULT_PROVIDER,
                "llm_model": llm_model or self.DEFAULT_MODEL,
                "system_message": "You are a helpful, creative, and concise assistant.",
                "user_message": prompt,
                "base_ip": "localhost",
                "port": "11434",
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
                "top_k": top_k,
                "seed": seed if seed != -1 else None,
                "repeat_penalty": 1.1,
                "stop": None,
                "keep_alive": True,
                "messages": [],
            }

            self._apply_provider_config(params, context)
            self._apply_provider_defaults(params, hide_thinking)
            self._apply_prompt_config(params, context)
            self._attach_video_file_frames(params)

            log_params = _sanitize_params_for_log(params)
            logger.info(f"[Non-Streaming] Making LLM request with params: {log_params}")

            try:
                # --- CALL NON-STREAMING VERSION ---
                response_data = run_async(
                    send_request(
                        llm_provider=params["llm_provider"],
                        base_ip=params.get("base_ip", "localhost"),
                        port=params.get("port", "11434"),
                        images=params.get("images"),
                        llm_model=params["llm_model"],
                        system_message=params["system_message"],
                        user_message=params["user_message"],
                        messages=params["messages"],
                        seed=params.get("seed"),
                        temperature=params["temperature"],
                        max_tokens=params["max_tokens"],
                        random=params.get("random", False),
                        top_k=params["top_k"],
                        top_p=params["top_p"],
                        repeat_penalty=params["repeat_penalty"],
                        stop=params.get("stop"),
                        keep_alive=params.get("keep_alive", True),
                        llm_api_key=params.get("llm_api_key"),
                        reasoning_format=params.get("reasoning_format")  # Pass reasoning_format
                    )
                )
                # --- END NON-STREAMING CALL ---
            except Exception as e:
                logger.error(f"Error in non-streaming send_request call: {e}", exc_info=True)
                response_data = {"choices": [{"message": {"content": f"Error calling send_request: {str(e)}"}}]}

            content = self._extract_content(response_data)

            # Apply thinking tag removal if requested
            if hide_thinking and content:
                content = _remove_thinking_tags(content)

            if context is not None and isinstance(context, dict):
                context_out = context.copy()
                context_out["llm_response"] = content
                context_out["llm_raw_response"] = response_data
            else:
                context_out = {"llm_response": content, "llm_raw_response": response_data, "passthrough_data": context}

            payload = ContextPayload(content, context_out)
            return {"ui": {"string": [content]}, "result": (payload, str(payload))}

        except Exception as e:
            error_message = f"Error generating text: {str(e)}"
            logger.error(error_message, exc_info=True)
            error_output = {"error": error_message, "original_input": context}
            payload = ContextPayload(error_message, error_output)
            return {"ui": {"string": [error_message]}, "result": (payload, str(payload))}
# --- Helper to avoid dumping large base64 blobs to INFO log -----------------
def _sanitize_params_for_log(d: dict) -> dict:
    """Build a log-safe shallow copy of a parameter dict.

    Non-empty image/frame fields are replaced by a short placeholder and any
    value stored under a key that looks like an API key is masked. All other
    entries are copied through unchanged.
    """
    blob_fields = ("images", "video_frames", "base64_images")

    def _redact(name, value):
        # Large binary blobs (images, video frames) -> concise placeholder
        if name in blob_fields and value:
            if isinstance(value, (list, tuple)):
                return f"[{len(value)} item(s) base64 omitted]"
            if isinstance(value, str):
                return "[base64 string omitted]"
            return "[data omitted]"

        # Anything that looks like an API key is never logged in full
        lowered = name.lower()
        looks_like_key = (
            "api_key" in lowered
            or lowered.endswith("_key")
            or lowered.endswith("key")
        )
        if looks_like_key:
            if not (isinstance(value, str) and value):
                return "[key hidden]"
            # Keep the first 5 chars for debugging, mask the rest
            return value[:5] + "…" if len(value) > 5 else "***"

        # Default passthrough
        return value

    return {name: _redact(name, value) for name, value in d.items()}
# --- NEW STREAMING NODE ---
class LLMToolkitTextGeneratorStream:
    """ComfyUI output node that generates text with an LLM and streams it to the UI.

    While the provider produces text, chunks are pushed to the front-end through
    ``PromptServer.instance.send_sync`` using the events
    ``llmtoolkit.stream.start``, ``llmtoolkit.stream.chunk``,
    ``llmtoolkit.stream.end`` and ``llmtoolkit.stream.error``.
    """

    DEFAULT_PROVIDER = "openai"
    DEFAULT_MODEL: str = "gpt-4o-mini"
    MODEL_LIST: List[str] = [DEFAULT_MODEL]

    # Markers that open / close a model "thinking" section in streamed text.
    # The second entry of each tuple is the Kimi-VL variant.
    _THINK_OPEN_TAGS: Tuple[str, ...] = ("<think>", "◁think▷")
    _THINK_CLOSE_TAGS: Tuple[str, ...] = ("</think>", "◁/think▷")
    # Upper bound on the images payload once video frames are appended,
    # to avoid huge request bodies.
    _MAX_IMAGES = 8

    @classmethod
    def INPUT_TYPES(cls):
        """Return the ComfyUI input specification (required/optional/hidden) for this node."""
        return {
            "required": {
                "prompt": ("STRING", {"multiline": False, "default": "Write a detailed description of a futuristic city."}),
                "hide_thinking": ("BOOLEAN", {"default": True, "tooltip": "Hide model thinking process (content between tags)"})
            },
            "optional": {
                "context": ("*", {}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.05}),
                "max_tokens": ("INT", {"default": 1024, "min": 1, "max": 65536}),
                "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.05}),
                "top_k": ("INT", {"default": 40, "min": 0, "max": 500}),
                "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647, "tooltip": "-1 = random seed"})
            },
            "hidden": {  # hidden inputs: node id used to address websocket messages, plus model override
                "unique_id": "UNIQUE_ID",
                "llm_model": ("STRING", {"default": cls.DEFAULT_MODEL})
            },
        }

    RETURN_TYPES = ("*", "STRING")
    RETURN_NAMES = ("context", "text")
    FUNCTION = "generate_stream"
    CATEGORY = "🔗llm_toolkit/generators"
    OUTPUT_NODE = True  # Keep the JS widget logic

    @staticmethod
    def _partial_tag_suffix_len(text: str, tags: Tuple[str, ...]) -> int:
        """Return the length of the longest tail of *text* that is a proper prefix of any tag."""
        longest = min(len(text), max(len(tag) for tag in tags) - 1)
        for size in range(longest, 0, -1):
            tail = text[-size:]
            if any(len(tag) > size and tag.startswith(tail) for tag in tags):
                return size
        return 0

    @classmethod
    def _filter_thinking_chunk(cls, chunk: str, inside_thinking: bool, carry: str = "") -> Tuple[str, bool, str]:
        """Strip thinking sections from one streamed chunk.

        Args:
            chunk: Newly received text.
            inside_thinking: Whether the stream is currently inside a thinking section.
            carry: Text held back from the previous call because it might be
                the beginning of a tag that continues in this chunk.

        Returns:
            ``(visible_text, inside_thinking, carry)`` where ``visible_text`` is
            the part that may be shown to the user and the other two values are
            the state to pass to the next call.
        """
        text = carry + chunk
        visible: List[str] = []
        pos = 0
        while pos < len(text):
            tags = cls._THINK_CLOSE_TAGS if inside_thinking else cls._THINK_OPEN_TAGS
            hits = []
            for tag in tags:
                idx = text.find(tag, pos)
                if idx != -1:
                    hits.append((idx, tag))
            if hits:
                # Earliest tag wins; its real length is skipped (tags differ in length).
                idx, tag = min(hits)
                if not inside_thinking:
                    visible.append(text[pos:idx])
                pos = idx + len(tag)
                inside_thinking = not inside_thinking
                continue
            # No complete tag left: hold back a tail that could be a tag split
            # across two chunks, emit (or discard, when thinking) the rest.
            held = cls._partial_tag_suffix_len(text[pos:], tags)
            end = len(text) - held
            if not inside_thinking:
                visible.append(text[pos:end])
            return "".join(visible), inside_thinking, text[end:]
        return "".join(visible), inside_thinking, ""

    def generate_stream(self, prompt, hide_thinking, unique_id, llm_model=None, context=None, **kwargs):
        """Generate text and stream it to the UI (synchronous wrapper).

        Args:
            prompt: User prompt from the node widget; when non-empty it overrides
                any ``user_prompt`` found in ``context`` (a ``prompt_config`` text
                in the context is applied afterwards and overrides it in turn).
            hide_thinking: When true, thinking sections are removed from both the
                streamed chunks and the final text.
            unique_id: Node id included in every websocket message.
            llm_model: Optional model name; falls back to ``DEFAULT_MODEL``.
            context: Optional upstream context. When it is a dict, matching keys,
                ``provider_config`` and ``prompt_config`` are read from it.
            **kwargs: Optional sampling inputs (``temperature``, ``max_tokens``,
                ``top_p``, ``top_k``, ``seed``).

        Returns:
            A dict ``{"ui": {"string": [...]}, "result": (payload, str(payload))}``.
            Errors are reported in the same shape rather than raised.
        """

        async def _async_generate():
            if PromptServer is None:
                logger.error("PromptServer not available. Cannot stream.")
                error_msg = "Streaming requires PromptServer, which is not available."
                error_output = {"error": error_msg, "original_input": context}
                payload = ContextPayload(error_msg, error_output)
                return {"ui": {"string": [error_msg]}, "result": (payload, str(payload))}

            server = PromptServer.instance
            full_response_text = ""
            thinking_carry = ""  # possible partial tag held back between chunks
            inside_thinking = False
            try:
                # 1. Start with node defaults
                params = {
                    "llm_provider": self.DEFAULT_PROVIDER,
                    "llm_model": llm_model or self.DEFAULT_MODEL,
                    "system_message": "You are a helpful, creative, and concise assistant.",
                    "user_message": prompt,
                    "base_ip": "localhost", "port": "11434",
                    "temperature": kwargs.get('temperature', 0.7), "max_tokens": kwargs.get('max_tokens', 1024), "top_p": kwargs.get('top_p', 0.9), "top_k": kwargs.get('top_k', 40), "seed": kwargs.get('seed', -1),
                    "repeat_penalty": 1.1, "stop": None, "keep_alive": "5m",
                    "messages": [], "llm_api_key": None,
                }

                # 2. Apply general settings from the incoming context
                if context and isinstance(context, dict):
                    for key, value in context.items():
                        if key in params and value is not None:
                            params[key] = value

                # 3. Apply specific provider_config settings, which take precedence
                provider_config = None
                if context and isinstance(context, dict):
                    if "provider_config" in context and isinstance(context["provider_config"], dict):
                        provider_config = context["provider_config"]
                    elif "provider_name" in context:  # Handle flat context from old nodes
                        provider_config = context
                if provider_config:
                    # First, copy every key that also exists in params
                    params.update({k: v for k, v in provider_config.items() if k in params and v is not None})
                    # Second, handle key-name differences (provider_name -> llm_provider, api_key -> llm_api_key)
                    if "provider_name" in provider_config and provider_config["provider_name"]:
                        params["llm_provider"] = provider_config["provider_name"]
                    elif context and isinstance(context, dict) and context.get("provider_name"):
                        # Fallback to root-level key if nested provider_config didn't have it
                        params["llm_provider"] = context["provider_name"]
                    if "api_key" in provider_config and provider_config["api_key"]:
                        params["llm_api_key"] = provider_config["api_key"]
                    elif context and isinstance(context, dict) and context.get("api_key"):
                        params["llm_api_key"] = context["api_key"]
                    if "llm_model" in provider_config and provider_config["llm_model"]:
                        params["llm_model"] = provider_config["llm_model"]
                    elif context and isinstance(context, dict) and context.get("llm_model"):
                        params["llm_model"] = context["llm_model"]

                # 4. user_prompt from context / provider_config may be used,
                #    but a non-empty node prompt overrides both.
                if context and isinstance(context, dict) and "user_prompt" in context and context["user_prompt"]:
                    params["user_message"] = context.get("user_prompt")
                if provider_config and "user_prompt" in provider_config and provider_config["user_prompt"]:
                    params["user_message"] = provider_config.get("user_prompt")
                if prompt:
                    params["user_message"] = prompt

                # Finalize model name fallback
                provider = "unknown"
                if params.get("llm_provider"):
                    provider = str(params["llm_provider"]).lower()
                if not params.get("llm_model"):
                    PROVIDER_DEFAULTS = {"openai": "gpt-4o-mini", "anthropic": "claude-3-opus-20240229"}
                    fallback = PROVIDER_DEFAULTS.get(provider)
                    params["llm_model"] = fallback or self.DEFAULT_MODEL

                # Groq reasoning format
                if provider == "groq":
                    params["reasoning_format"] = "hidden" if hide_thinking else "raw"

                # Auto-fetch API key for OpenAI if missing/placeholder
                if provider == "openai" and (not params.get("llm_api_key") or params["llm_api_key"] in {"", "1234", None}):
                    try:
                        params["llm_api_key"] = get_api_key("OPENAI_API_KEY", "openai")
                        logger.info("generate_stream: Retrieved OpenAI API key via get_api_key helper.")
                    except ValueError as _e:
                        logger.warning(f"generate_stream: get_api_key failed – {_e}")

                # --- Prompt-Manager support ---
                # If upstream attached a prompt_config dict, use its data
                prompt_cfg = None
                if context is not None and isinstance(context, dict):
                    prompt_cfg = context.get("prompt_config")
                if prompt_cfg and isinstance(prompt_cfg, dict):
                    # Override user prompt text if provided
                    if prompt_cfg.get("text"):
                        params["user_message"] = prompt_cfg["text"]
                    # Collect images and video frames (expected to be base64 already)
                    imgs = []
                    if prompt_cfg.get("image_base64"):
                        img_val = prompt_cfg["image_base64"]
                        if isinstance(img_val, str):
                            imgs.append(img_val)
                        elif isinstance(img_val, list):
                            imgs.extend(img_val)
                    if prompt_cfg.get("video_frames_base64"):
                        vid_frames = prompt_cfg["video_frames_base64"]
                        if isinstance(vid_frames, list):
                            imgs.extend(vid_frames)
                        else:
                            imgs.append(vid_frames)
                    params["images"] = imgs if imgs else None
                    # Pass through file paths and URLs for APIs that support them
                    if prompt_cfg.get("file_paths"):
                        params["file_paths"] = prompt_cfg["file_paths"]
                        logger.debug("PromptManager: Passing file_paths to API (some APIs process videos/PDFs directly).")
                    if prompt_cfg.get("urls"):
                        params["urls"] = prompt_cfg["urls"]
                        logger.debug("PromptManager: Passing URLs to API.")
                    # Audio path remains separate
                    if prompt_cfg.get("audio_path"):
                        params["audio_path"] = prompt_cfg["audio_path"]
                else:
                    params["images"] = None

                # --- If file_paths contain local video files, extract a few frames ---
                # Best effort: a failure here is logged and generation continues.
                try:
                    file_paths_value = params.get("file_paths")
                    candidate_paths: List[str] = []
                    if isinstance(file_paths_value, str) and file_paths_value.strip():
                        candidate_paths = [file_paths_value.strip()]
                    elif isinstance(file_paths_value, list):
                        candidate_paths = [p for p in file_paths_value if isinstance(p, str) and p.strip()]
                    extracted_from_files: List[str] = []
                    for p in candidate_paths:
                        if _is_video_file(p) and os.path.isfile(p):
                            extracted_from_files.extend(_extract_video_file_frames_as_b64(p))
                    if extracted_from_files:
                        if params.get("images") is None:
                            params["images"] = []
                        if not isinstance(params["images"], list):
                            params["images"] = [params["images"]]  # normalize
                        # cap total images to avoid huge payloads
                        remaining = max(0, self._MAX_IMAGES - len(params["images"]))
                        params["images"].extend(extracted_from_files[:remaining])
                        logger.info(
                            "[Streaming] Added %d frame(s) extracted from video file paths to images payload.",
                            min(len(extracted_from_files), remaining),
                        )
                except Exception as e:
                    logger.warning(
                        "[Streaming] Failed to process video file paths for frame extraction: %s",
                        e,
                        exc_info=True,
                    )

                log_params = _sanitize_params_for_log(params)
                logger.info(
                    f"[Streaming] Initiating LLM stream with params: {log_params} for node {unique_id}"
                )

                # --- Send START message ---
                server.send_sync("llmtoolkit.stream.start", {"node": unique_id}, sid=server.client_id)

                # --- Initiate and process the stream ---
                stream_generator = send_request_stream(
                    llm_provider=params["llm_provider"],
                    base_ip=params.get("base_ip", "localhost"),
                    port=params.get("port", "11434"),
                    llm_model=params["llm_model"],
                    system_message=params["system_message"],
                    user_message=params["user_message"],
                    messages=params["messages"],
                    seed=params.get("seed"),
                    temperature=params["temperature"],
                    max_tokens=params["max_tokens"],
                    random=params.get("random", False),
                    top_k=params["top_k"],
                    top_p=params["top_p"],
                    repeat_penalty=params["repeat_penalty"],
                    stop=params.get("stop"),
                    keep_alive=params.get("keep_alive", True),
                    llm_api_key=params.get("llm_api_key"),
                    base64_images=params.get("images"),
                )
                async for chunk in stream_generator:
                    if chunk:
                        full_response_text += chunk
                        if hide_thinking:
                            # Drop thinking sections; state survives across chunks
                            # so tags split between two chunks are still caught.
                            chunk_to_send, inside_thinking, thinking_carry = self._filter_thinking_chunk(
                                chunk, inside_thinking, thinking_carry
                            )
                        else:
                            chunk_to_send = chunk
                        if chunk_to_send:
                            server.send_sync(
                                "llmtoolkit.stream.chunk",
                                {"node": unique_id, "text": chunk_to_send},
                                sid=server.client_id,
                            )
                    # Yield to the event loop between chunks
                    await asyncio.sleep(0.001)

                # Text held back as a possible tag start turned out to be plain text
                if hide_thinking and thinking_carry and not inside_thinking:
                    server.send_sync(
                        "llmtoolkit.stream.chunk",
                        {"node": unique_id, "text": thinking_carry},
                        sid=server.client_id,
                    )

                # Apply thinking tag removal to final text if requested
                final_text = full_response_text
                if hide_thinking:
                    final_text = _remove_thinking_tags(full_response_text)
                logger.info(
                    f"[Streaming] Finished for node {unique_id}. Total length: {len(final_text)}"
                )
                server.send_sync(
                    "llmtoolkit.stream.end",
                    {"node": unique_id, "final_text": final_text},
                    sid=server.client_id,
                )

                # --- Prepare final context output ---
                if context is not None and isinstance(context, dict):
                    context_out = context.copy()
                    context_out["llm_response"] = final_text
                    context_out["llm_raw_response"] = {
                        "status": "Streamed successfully",
                        "final_length": len(final_text),
                    }
                else:
                    # Non-dict context is preserved untouched for downstream nodes
                    context_out = {
                        "llm_response": final_text,
                        "llm_raw_response": {
                            "status": "Streamed successfully",
                            "final_length": len(final_text),
                        },
                        "passthrough_data": context,
                    }
                payload = ContextPayload(final_text, context_out)
                return {"ui": {"string": [final_text]}, "result": (payload, str(payload))}
            except Exception as e:
                # Top-level boundary: report the failure to the UI and return it
                # as node output together with whatever text was received so far.
                error_message = f"Error during streaming generation: {str(e)}"
                logger.error(error_message, exc_info=True)
                if server and unique_id:
                    server.send_sync(
                        "llmtoolkit.stream.error",
                        {"node": unique_id, "error": error_message},
                        sid=server.client_id,
                    )
                error_output = {
                    "error": error_message,
                    "partial_response": full_response_text,
                    "original_input": context,
                }
                payload = ContextPayload(error_message, error_output)
                return {
                    "ui": {"string": [f"Error: {error_message}\nPartial: {full_response_text}"]},
                    "result": (payload, str(payload)),
                }

        # Execute the inner coroutine and return its result synchronously
        return run_async(_async_generate())
# ComfyUI registration tables, both derived from a single registry so the
# class mapping and the display-name mapping cannot drift apart.
# Each entry: (node key, node class, display name shown in the UI).
_NODE_REGISTRY = (
    ("LLMToolkitTextGenerator", LLMToolkitTextGenerator, "Generate Text (🔗LLMToolkit)"),  # original node
    ("LLMToolkitTextGeneratorStream", LLMToolkitTextGeneratorStream, "Generate Text Stream (🔗LLMToolkit)"),  # streaming node
)
NODE_CLASS_MAPPINGS = {node_key: node_cls for node_key, node_cls, _ in _NODE_REGISTRY}
NODE_DISPLAY_NAME_MAPPINGS = {node_key: title for node_key, _, title in _NODE_REGISTRY}