# Mirror of https://github.com/comfyanonymous/ComfyUI.git
# (synced 2026-07-03 21:20:49 +08:00; 1262 lines, 59 KiB, Python)
# generate_text.py
|
||
import os
|
||
import sys
|
||
import json
|
||
import base64
|
||
import asyncio
|
||
import logging
|
||
from typing import Dict, Any, List, Optional, Union, Tuple, AsyncGenerator
|
||
|
||
# Set up logging
# NOTE(review): basicConfig configures the *root* logger at import time. Per the
# stdlib docs it does nothing when the root logger already has handlers, so in a
# host that configures logging first this is presumably a no-op — confirm.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Add parent directory to path to ensure imports work correctly.
# This must run before the project-local imports below (send_request, api.*,
# llmtoolkit_utils, context_payload), which presumably live in that directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Try to import ComfyUI-specific modules. They are optional so the module can
# still be imported outside ComfyUI; both names fall back to None.
try:
    import folder_paths
    from server import PromptServer  # ComfyUI's server singleton (None when unavailable)
except ImportError:
    logger.warning("Could not import folder_paths or server. Make sure ComfyUI environment is set up.")
    folder_paths = None
    PromptServer = None

# Check for required dependencies.
# NOTE(review): when aiohttp is missing the name `aiohttp` stays unbound, so the
# streaming code paths that use it fail with NameError only when called.
missing_deps = []
try:
    import aiohttp
except ImportError:
    missing_deps.append("aiohttp")

if missing_deps:
    logger.warning(f"Missing dependencies: {', '.join(missing_deps)}. Some functionality may not work.")
    logger.warning("Please install missing dependencies: pip install " + " ".join(missing_deps))

# Project helpers: non-streaming request + async runner (used by the fallback
# path and the non-streaming node), GPT-5 detection, provider streamers and
# Ollama / API-key utilities.
from send_request import send_request, run_async, is_gpt5_model  # Keep non-streaming version if needed
from api.openai_api import send_openai_responses_stream
from api.gemini_api import send_gemini_request_stream
from llmtoolkit_utils import query_local_ollama_models, ensure_ollama_server, ensure_ollama_model, get_api_key

# Local transformers streaming (optional) - removed unused import.
# The name is still defined (as None), presumably so any external reference to
# it keeps resolving; nothing in this module calls it.
send_transformers_request_stream = None  # type: ignore

# Payload helper to embed context into a string subclass
from context_payload import ContextPayload
|
||
|
||
# -----------------------------------------------------------------------------
|
||
# Helpers: Lazy OpenCV import and video -> base64 frame extraction
|
||
# -----------------------------------------------------------------------------
|
||
# Lazily imported OpenCV module (None when unavailable). The flag records that an
# import was already attempted, so a missing cv2 is probed and reported only once.
_cv2_ref = None
_cv2_import_attempted = False


def _get_cv2():
    """Return the ``cv2`` module, importing it on first use, or None if unavailable.

    The import is attempted at most once per process. After a failure, the
    warning is not repeated and later calls return None immediately instead
    of re-trying the import each time.
    """
    global _cv2_ref, _cv2_import_attempted
    if not _cv2_import_attempted:
        _cv2_import_attempted = True
        try:
            import cv2 as _cv2  # type: ignore

            _cv2_ref = _cv2
        except Exception:
            _cv2_ref = None
            logger.warning("cv2 not available. Video file frame extraction disabled.")
    return _cv2_ref
|
||
|
||
|
||
def _is_video_file(path: str) -> bool:
|
||
try:
|
||
ext = os.path.splitext(path)[1].lower()
|
||
except Exception:
|
||
return False
|
||
return ext in {".mp4", ".mov", ".mkv", ".avi", ".webm", ".m4v"}
|
||
|
||
|
||
def _extract_video_file_frames_as_b64(
|
||
video_path: str,
|
||
max_frames: int = 5,
|
||
stride: int = 16,
|
||
) -> List[str]:
|
||
"""Extract up to `max_frames` JPEG base64 frames from a local video file.
|
||
|
||
- Uses a simple stride to subsample frames.
|
||
- Falls back to the first N frames if the video is very short.
|
||
- Returns an empty list on any error or if cv2 is unavailable.
|
||
"""
|
||
cv2 = _get_cv2()
|
||
if cv2 is None:
|
||
return []
|
||
|
||
try:
|
||
cap = cv2.VideoCapture(video_path)
|
||
if not cap.isOpened():
|
||
logger.warning("Could not open video file: %s", video_path)
|
||
return []
|
||
|
||
frames: List[str] = []
|
||
frame_index = 0
|
||
picked = 0
|
||
|
||
# Try stride sampling first
|
||
while picked < max_frames:
|
||
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
|
||
ok, frame = cap.read()
|
||
if not ok:
|
||
break
|
||
success, buf = cv2.imencode(".jpg", frame)
|
||
if success:
|
||
frames.append(base64.b64encode(buf.tobytes()).decode("ascii"))
|
||
picked += 1
|
||
frame_index += stride
|
||
|
||
# If we got nothing with stride, try the first sequential frames
|
||
if not frames:
|
||
cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
|
||
count = 0
|
||
while count < max_frames:
|
||
ok, frame = cap.read()
|
||
if not ok:
|
||
break
|
||
success, buf = cv2.imencode(".jpg", frame)
|
||
if success:
|
||
frames.append(base64.b64encode(buf.tobytes()).decode("ascii"))
|
||
count += 1
|
||
|
||
cap.release()
|
||
if frames:
|
||
logger.debug("Extracted %d frame(s) from video file %s", len(frames), video_path)
|
||
return frames
|
||
except Exception as e:
|
||
logger.warning("Error extracting frames from %s: %s", video_path, e, exc_info=True)
|
||
return []
|
||
|
||
# --- NEW STREAMING REQUEST FUNCTION (Example for Ollama) ---
|
||
# IMPORTANT: This needs to be adapted based on the actual API structure of the provider!
|
||
async def send_request_stream(
|
||
llm_provider: str,
|
||
base_ip: str,
|
||
port: str,
|
||
llm_model: str,
|
||
system_message: str,
|
||
user_message: str,
|
||
messages: List[Dict[str, str]],
|
||
seed: Optional[int] = None,
|
||
temperature: float = 0.7,
|
||
max_tokens: int = 1024,
|
||
random: bool = False,
|
||
top_k: int = 40,
|
||
top_p: float = 0.9,
|
||
repeat_penalty: float = 1.1,
|
||
stop: Optional[List[str]] = None,
|
||
keep_alive: Union[bool, str] = True,
|
||
llm_api_key: Optional[str] = None,
|
||
timeout: int = 120, # Add a timeout for the connection
|
||
base64_images: Optional[List[str]] = None,
|
||
) -> AsyncGenerator[str, None]:
|
||
"""
|
||
Sends a streaming request to an LLM provider (Example for Ollama).
|
||
Yields text chunks as they are received.
|
||
"""
|
||
provider_lower = llm_provider.lower()
|
||
|
||
if provider_lower in ["openai", "openrouter"]:
|
||
# --- OpenAI & OpenRouter Specific Streaming Logic ---
|
||
if not llm_api_key:
|
||
logger.error(f"{llm_provider} streaming requested but no API key supplied.")
|
||
yield f"[{llm_provider} Error: API key missing]"
|
||
return
|
||
|
||
# Check if this is a GPT-5 model and use Responses API (OpenAI only)
|
||
if provider_lower == "openai" and is_gpt5_model(llm_model):
|
||
logger.info(f"Detected GPT-5 model: {llm_model}, using Responses API for streaming")
|
||
try:
|
||
async for chunk in send_openai_responses_stream(
|
||
api_url="https://api.openai.com/v1/responses",
|
||
base64_images=base64_images,
|
||
model=llm_model,
|
||
system_message=system_message,
|
||
user_message=user_message,
|
||
messages=messages,
|
||
api_key=llm_api_key,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
top_p=top_p,
|
||
):
|
||
yield chunk
|
||
return
|
||
except Exception as e:
|
||
logger.warning(f"GPT-5 Responses stream failed: {e}, falling back to Chat Completions")
|
||
# Fall through to regular OpenAI streaming below
|
||
|
||
api_url = "https://openrouter.ai/api/v1/chat/completions" if provider_lower == "openrouter" else "https://api.openai.com/v1/chat/completions"
|
||
|
||
headers = {
|
||
"Authorization": f"Bearer {llm_api_key}",
|
||
"Content-Type": "application/json",
|
||
}
|
||
# Build message list if not provided
|
||
if not messages:
|
||
messages = []
|
||
if system_message:
|
||
messages.append({"role": "system", "content": system_message})
|
||
|
||
# Handle multimodal user content (text + images)
|
||
if base64_images:
|
||
content_blocks = []
|
||
if user_message:
|
||
content_blocks.append({"type": "text", "text": user_message})
|
||
for img_b64 in base64_images:
|
||
content_blocks.append({
|
||
"type": "image_url",
|
||
"image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
|
||
})
|
||
messages.append({"role": "user", "content": content_blocks})
|
||
else:
|
||
if user_message:
|
||
messages.append({"role": "user", "content": user_message})
|
||
|
||
payload = {
|
||
"model": llm_model,
|
||
"messages": messages,
|
||
"temperature": temperature,
|
||
"max_tokens": max_tokens,
|
||
"top_p": top_p,
|
||
"stream": True,
|
||
}
|
||
# Remove None values
|
||
payload = {k: v for k, v in payload.items() if v is not None}
|
||
|
||
logger.info(f"Streaming request to {llm_provider}: model={llm_model}")
|
||
session = None
|
||
try:
|
||
# Create session with custom connector for Windows
|
||
connector = aiohttp.TCPConnector(force_close=True) if sys.platform == 'win32' else None
|
||
session = aiohttp.ClientSession(
|
||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||
connector=connector
|
||
)
|
||
|
||
async with session.post(api_url, headers=headers, json=payload) as response:
|
||
response.raise_for_status()
|
||
async for raw_line in response.content:
|
||
if not raw_line:
|
||
continue
|
||
line = raw_line.decode("utf-8").strip()
|
||
if not line:
|
||
continue
|
||
# OpenAI streams multiple lines that may begin with 'data:'; join those if needed.
|
||
if line.startswith("data: "):
|
||
data_str = line[len("data: ") :].strip()
|
||
else:
|
||
data_str = line
|
||
if data_str == "[DONE]":
|
||
break
|
||
try:
|
||
data_json = json.loads(data_str)
|
||
choices = data_json.get("choices", [])
|
||
if choices:
|
||
delta = choices[0].get("delta", {})
|
||
content_piece = delta.get("content")
|
||
if content_piece:
|
||
yield content_piece
|
||
except json.JSONDecodeError:
|
||
logger.warning(f"Could not decode JSON line from OpenAI stream: {data_str}")
|
||
except Exception as e:
|
||
logger.error(f"Error during OpenAI streaming: {e}", exc_info=True)
|
||
yield f"[OpenAI streaming error: {e}]"
|
||
finally:
|
||
# Ensure session is properly closed
|
||
if session and not session.closed:
|
||
await session.close()
|
||
# Small delay to allow cleanup on Windows
|
||
if sys.platform == 'win32':
|
||
await asyncio.sleep(0.1)
|
||
return
|
||
|
||
if provider_lower == "gemini":
|
||
if not llm_api_key:
|
||
logger.error("Gemini streaming requested but no API key supplied.")
|
||
yield "[Gemini Error: API key missing]"
|
||
return
|
||
|
||
async for chunk in send_gemini_request_stream(
|
||
api_key=llm_api_key,
|
||
model=llm_model,
|
||
system_message=system_message,
|
||
user_message=user_message,
|
||
messages=messages,
|
||
base64_images=base64_images,
|
||
temperature=temperature,
|
||
max_tokens=max_tokens,
|
||
top_p=top_p,
|
||
top_k=top_k,
|
||
):
|
||
yield chunk
|
||
return
|
||
|
||
if provider_lower == "ollama":
|
||
# --- Ollama Specific Streaming Logic ---
|
||
if not ensure_ollama_server(base_ip, port):
|
||
logger.error("Ollama daemon unavailable and could not be started – aborting stream.")
|
||
yield "[Error: Ollama daemon unavailable]"
|
||
return
|
||
|
||
# Ensure requested model is present locally (will pull if missing)
|
||
ensure_ollama_model(llm_model, base_ip, port)
|
||
|
||
# Ollama expects plain base64 strings without the 'data:image/...' prefix
|
||
if base64_images:
|
||
base64_images = [
|
||
img.split("base64,")[1] if isinstance(img, str) and "base64," in img else img
|
||
for img in base64_images
|
||
]
|
||
|
||
# Decide which Ollama endpoint to use:
|
||
# • /api/generate – fast text-only streaming (no image support)
|
||
# • /api/chat – full chat/completions with image support
|
||
# We switch to /api/chat automatically if the caller supplied base64-encoded images
|
||
# so that vision models (e.g. qwen-vl, llava) receive the image bytes.
|
||
|
||
use_chat_endpoint = bool(base64_images) # True if we have images to send
|
||
|
||
url = (
|
||
f"http://{base_ip}:{port}/api/chat" if use_chat_endpoint else f"http://{base_ip}:{port}/api/generate"
|
||
)
|
||
|
||
headers = {"Content-Type": "application/json"}
|
||
if llm_api_key: # Ollama doesn't typically use API keys this way, but include for consistency
|
||
headers["Authorization"] = f"Bearer {llm_api_key}"
|
||
|
||
# ------------------------------------------------------------------
|
||
# Build request payloads
|
||
# ------------------------------------------------------------------
|
||
|
||
if use_chat_endpoint:
|
||
# --------------------------------------------------------------
|
||
# /api/chat (supports multimodal, messages array)
|
||
# --------------------------------------------------------------
|
||
if not messages:
|
||
messages = []
|
||
if system_message:
|
||
messages.append({"role": "system", "content": system_message})
|
||
|
||
# Always append the user message at the end so that vision models
|
||
# get the most recent prompt + images in a single message.
|
||
user_msg: dict[str, Any] = {"role": "user", "content": user_message or ""}
|
||
if base64_images:
|
||
user_msg["images"] = base64_images # Ollama expects a list key called "images"
|
||
messages.append(user_msg)
|
||
|
||
payload = {
|
||
"model": llm_model,
|
||
"messages": messages,
|
||
"stream": True,
|
||
"options": {
|
||
"seed": seed,
|
||
"temperature": temperature,
|
||
"num_predict": max_tokens,
|
||
"top_k": top_k,
|
||
"top_p": top_p,
|
||
"repeat_penalty": repeat_penalty,
|
||
"stop": stop,
|
||
},
|
||
}
|
||
else:
|
||
# --------------------------------------------------------------
|
||
# /api/generate (text-only)
|
||
# --------------------------------------------------------------
|
||
# Construct messages list if not provided directly (for context)
|
||
if not messages:
|
||
messages = []
|
||
if system_message:
|
||
messages.append({"role": "system", "content": system_message})
|
||
if user_message:
|
||
messages.append({"role": "user", "content": user_message})
|
||
|
||
payload = {
|
||
"model": llm_model,
|
||
"prompt": user_message,
|
||
"system": system_message if system_message else None,
|
||
"stream": True,
|
||
"options": {
|
||
"seed": seed,
|
||
"temperature": temperature,
|
||
"num_predict": max_tokens,
|
||
"top_k": top_k,
|
||
"top_p": top_p,
|
||
"repeat_penalty": repeat_penalty,
|
||
"stop": stop,
|
||
},
|
||
}
|
||
# Clean up None values Ollama might not like
|
||
if not system_message:
|
||
del payload["system"]
|
||
|
||
# Remove None values from options dict
|
||
if "options" in payload:
|
||
payload["options"] = {k: v for k, v in payload["options"].items() if v is not None}
|
||
|
||
logger.info(f"Streaming request to Ollama: {url} with payload: {{'model': '{llm_model}', 'stream': True, 'use_chat': {use_chat_endpoint}}}")
|
||
|
||
try:
|
||
# Use a single session if possible, manage timeouts
|
||
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
|
||
async with session.post(url, headers=headers, json=payload) as response:
|
||
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
|
||
|
||
# Process the streaming response line by line. The JSON schema differs slightly
|
||
# between /api/generate and /api/chat, so we branch inside the loop.
|
||
async for line in response.content:
|
||
if not line:
|
||
continue
|
||
try:
|
||
decoded_line = line.decode("utf-8").strip()
|
||
if not decoded_line:
|
||
continue
|
||
|
||
data = json.loads(decoded_line)
|
||
|
||
if use_chat_endpoint:
|
||
# Chat endpoint returns {"message": {"content": "..."}, "done": bool}
|
||
msg = data.get("message", {})
|
||
chunk = msg.get("content", "")
|
||
is_done = data.get("done", False)
|
||
else:
|
||
# Generate endpoint returns {"response": "...", "done": bool}
|
||
chunk = data.get("response", "")
|
||
is_done = data.get("done", False)
|
||
|
||
if chunk:
|
||
yield chunk
|
||
|
||
if is_done:
|
||
logger.info("Ollama stream finished.")
|
||
break
|
||
except json.JSONDecodeError:
|
||
logger.warning(
|
||
f"Could not decode JSON line from Ollama stream: {line.decode('utf-8', errors='ignore')}"
|
||
)
|
||
except Exception as e:
|
||
logger.error(f"Error processing Ollama stream line: {e}", exc_info=True)
|
||
yield f"[Error processing stream: {e}]"
|
||
break # Stop streaming on error
|
||
|
||
except aiohttp.ClientConnectorError as e:
|
||
logger.error(f"Connection error to {url}: {e}")
|
||
yield f"[Connection Error: {e}]"
|
||
except aiohttp.ClientResponseError as e:
|
||
logger.error(f"HTTP error {e.status} from {url}: {e.message}")
|
||
# Attempt to read error details from response body
|
||
try:
|
||
error_body = await e.response.text() if hasattr(e, 'response') else 'No details'
|
||
logger.error(f"Error Body: {error_body}")
|
||
yield f"[HTTP Error {e.status}: {e.message} - {error_body[:100]}]"
|
||
except:
|
||
yield f"[HTTP Error {e.status}: {e.message}]"
|
||
except asyncio.TimeoutError:
|
||
logger.error(f"Request timed out after {timeout} seconds to {url}")
|
||
yield f"[Timeout Error]"
|
||
except ConnectionResetError as e:
|
||
logger.warning(f"Connection reset by peer: {e}")
|
||
yield f"[Connection Reset: The server closed the connection]"
|
||
except OSError as e:
|
||
if e.errno == 10054: # Windows specific: connection forcibly closed
|
||
logger.warning("Connection forcibly closed by remote host")
|
||
yield f"[Connection closed by server]"
|
||
else:
|
||
logger.error(f"OS Error during streaming: {e}")
|
||
yield f"[Network Error: {e}]"
|
||
except Exception as e:
|
||
logger.error(f"An unexpected error occurred during streaming request: {e}", exc_info=True)
|
||
yield f"[Unexpected Error: {e}]"
|
||
|
||
if provider_lower == "groq":
|
||
# --- Groq Specific Streaming Logic ---
|
||
if not llm_api_key:
|
||
logger.error("Groq streaming requested but no API key supplied.")
|
||
# Fallback could be added here if desired, for now just yield error
|
||
yield "[Groq Error: API key missing]"
|
||
return
|
||
|
||
groq_url = "https://api.groq.com/openai/v1/chat/completions"
|
||
headers = {
|
||
"Authorization": f"Bearer {llm_api_key}",
|
||
"Content-Type": "application/json",
|
||
}
|
||
|
||
# Handle multimodal user content (text + images)
|
||
is_vision_model = "scout" in llm_model or "maverick" in llm_model
|
||
images_to_send = base64_images
|
||
if images_to_send and not is_vision_model:
|
||
logger.warning("Groq stream: Model '%s' may not support images. Sending without.", llm_model)
|
||
images_to_send = None
|
||
elif images_to_send and len(images_to_send) > 5:
|
||
logger.warning("Groq stream: Taking first 5 of %s images for vision model.", len(images_to_send))
|
||
images_to_send = images_to_send[:5]
|
||
|
||
# Build message list if not provided
|
||
if not messages:
|
||
messages = []
|
||
if system_message:
|
||
messages.append({"role": "system", "content": system_message})
|
||
|
||
if images_to_send:
|
||
content_blocks = [{"type": "text", "text": user_message}]
|
||
for img_b64 in images_to_send:
|
||
content_blocks.append({
|
||
"type": "image_url",
|
||
"image_url": {"url": f"data:image/jpeg;base64,{img_b64}"},
|
||
})
|
||
messages.append({"role": "user", "content": content_blocks})
|
||
else:
|
||
if user_message:
|
||
messages.append({"role": "user", "content": user_message})
|
||
|
||
payload = {
|
||
"model": llm_model,
|
||
"messages": messages,
|
||
"temperature": temperature,
|
||
"max_completion_tokens": max_tokens, # Use correct Groq param
|
||
"top_p": top_p,
|
||
"stream": True,
|
||
}
|
||
# Remove None values
|
||
payload = {k: v for k, v in payload.items() if v is not None}
|
||
|
||
logger.info(f"Streaming request to Groq: model={llm_model}")
|
||
session = None
|
||
try:
|
||
# Create session with custom connector for Windows
|
||
connector = aiohttp.TCPConnector(force_close=True) if sys.platform == 'win32' else None
|
||
session = aiohttp.ClientSession(
|
||
timeout=aiohttp.ClientTimeout(total=timeout),
|
||
connector=connector
|
||
)
|
||
|
||
async with session.post(groq_url, headers=headers, json=payload) as response:
|
||
response.raise_for_status()
|
||
async for raw_line in response.content:
|
||
line = raw_line.decode("utf-8").strip()
|
||
if not line: continue
|
||
if line.startswith("data: "):
|
||
data_str = line[len("data: ") :].strip()
|
||
else:
|
||
data_str = line
|
||
if data_str == "[DONE]":
|
||
break
|
||
try:
|
||
data_json = json.loads(data_str)
|
||
choices = data_json.get("choices", [])
|
||
if choices:
|
||
delta = choices[0].get("delta", {})
|
||
content_piece = delta.get("content")
|
||
if content_piece:
|
||
yield content_piece
|
||
except json.JSONDecodeError:
|
||
logger.warning(f"Could not decode JSON line from Groq stream: {data_str}")
|
||
except Exception as e:
|
||
logger.error(f"Error during Groq streaming: {e}", exc_info=True)
|
||
yield f"[Groq streaming error: {e}]"
|
||
finally:
|
||
# Ensure session is properly closed
|
||
if session and not session.closed:
|
||
await session.close()
|
||
# Small delay to allow cleanup on Windows
|
||
if sys.platform == 'win32':
|
||
await asyncio.sleep(0.1)
|
||
return
|
||
|
||
# Existing fallback logic for other providers
|
||
if provider_lower not in ["ollama", "openai", "openrouter", "transformers", "hf", "local", "groq", "gemini"]:
|
||
logger.warning(f"Streaming not implemented for provider '{llm_provider}'. Falling back to non-streaming.")
|
||
try:
|
||
full_response_data = await send_request(
|
||
llm_provider=llm_provider, base_ip=base_ip, port=port, images=base64_images, llm_model=llm_model,
|
||
system_message=system_message, user_message=user_message, messages=messages, seed=seed,
|
||
temperature=temperature, max_tokens=max_tokens, random=random, top_k=top_k, top_p=top_p,
|
||
repeat_penalty=repeat_penalty, stop=stop, keep_alive=keep_alive, llm_api_key=llm_api_key
|
||
)
|
||
if isinstance(full_response_data, dict):
|
||
if "choices" in full_response_data and full_response_data["choices"]:
|
||
message = full_response_data["choices"][0].get("message", {})
|
||
content = message.get("content", "")
|
||
if content: yield content
|
||
elif "response" in full_response_data:
|
||
if full_response_data["response"]: yield full_response_data["response"]
|
||
elif "candidates" in full_response_data and full_response_data.get("candidates"):
|
||
try:
|
||
content = full_response_data["candidates"][0]["content"]["parts"][0]["text"]
|
||
if content: yield content
|
||
except (KeyError, IndexError, TypeError):
|
||
logger.warning(f"Could not parse Gemini response format: {full_response_data}")
|
||
else:
|
||
logger.error(f"Unexpected non-streaming response format: {full_response_data}")
|
||
elif isinstance(full_response_data, str):
|
||
if full_response_data: yield full_response_data
|
||
else:
|
||
logger.error(f"Unexpected non-streaming response type: {type(full_response_data)}")
|
||
|
||
except Exception as e:
|
||
logger.error(f"Error in fallback non-streaming request for {llm_provider}: {e}", exc_info=True)
|
||
yield f"[Error: {e}]"
|
||
return # Stop generation after yielding the fallback response
|
||
|
||
def _remove_thinking_tags(text: str) -> str:
|
||
"""Remove <think>...</think> blocks from text, including the tags themselves."""
|
||
import re
|
||
# Pattern to match <think>...</think> or ◁think▷...◁/think▷
|
||
pattern = r'<think>.*?</think>|◁think▷.*?◁/think▷'
|
||
cleaned = re.sub(pattern, '', text, flags=re.DOTALL)
|
||
# Clean up any extra whitespace/newlines left behind
|
||
cleaned = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned) # Replace multiple newlines with double
|
||
return cleaned.strip()
|
||
|
||
# --- Non-streaming text generation node ---
class LLMToolkitTextGenerator:
    """ComfyUI node that runs one blocking LLM request and returns its text.

    Provider settings come from an optional upstream ``context`` dict. It is
    either the provider config itself (has ``provider_name``) or a dict
    holding it under ``provider_config``. An optional ``prompt_config``
    (from a PromptManager node) can override the prompt and attach images,
    video frames, file paths, URLs or an audio path.
    """

    DEFAULT_PROVIDER = "openai"

    DEFAULT_MODEL: str = "gpt-4o-mini"

    MODEL_LIST: List[str] = [DEFAULT_MODEL]

    # Model used when the provider config does not name one.
    PROVIDER_DEFAULT_MODELS: Dict[str, str] = {"openai": "gpt-4o-mini", "anthropic": "claude-3-opus-20240229"}

    # Upper bound on the total number of images after adding video-file frames.
    MAX_IMAGES_WITH_VIDEO_FRAMES = 8

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": ("STRING", {"multiline": False, "default": "Write a short story about a robot learning to paint."}),
                "hide_thinking": ("BOOLEAN", {"default": True, "tooltip": "Hide model thinking process (content between <think> tags)"})
            },
            "optional": {
                "context": ("*", {}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.05}),
                "max_tokens": ("INT", {"default": 1024, "min": 1, "max": 65536}),
                "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.05}),
                "top_k": ("INT", {"default": 40, "min": 0, "max": 500}),
                "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647, "tooltip": "-1 = random seed"})
            },
            "hidden": {
                "llm_model": ("STRING", {"default": cls.DEFAULT_MODEL})
            }
        }

    RETURN_TYPES = ("*", "STRING")
    RETURN_NAMES = ("context", "text")
    FUNCTION = "generate"
    CATEGORY = "🔗llm_toolkit/generators"
    OUTPUT_NODE = True  # Keeps the text widget for non-streaming version

    @staticmethod
    def _extract_content(response_data: Any) -> str:
        """Pull the reply text out of a ``send_request`` result.

        Understands these shapes:
        - OpenAI-style ``choices``
        - Ollama-style ``response``
        - Gemini-style ``candidates``
        - plain strings

        Anything else becomes an ``"Error: ..."`` string instead of raising.
        """
        if response_data is None:
            return "Error: Received None response"
        if isinstance(response_data, str):
            return response_data
        if not isinstance(response_data, dict):
            return f"Error: Unexpected type: {type(response_data)}"

        if response_data.get("choices"):
            message = response_data["choices"][0].get("message") or {}
            content = message.get("content", "")
            return "Error: Null content in response" if content is None else content
        if "response" in response_data:
            content = response_data["response"]
            return "Error: Null content in response" if content is None else content
        if response_data.get("candidates"):
            # Same Gemini shape the streaming fallback in this module accepts.
            try:
                return response_data["candidates"][0]["content"]["parts"][0]["text"] or ""
            except (KeyError, IndexError, TypeError):
                pass
        return f"Error: Unexpected format: {str(response_data)}"

    @staticmethod
    def _collect_prompt_images(prompt_cfg: Dict[str, Any]) -> Optional[List[Any]]:
        """Gather PromptManager base64 images and video frames into one list (None if empty)."""
        imgs: List[Any] = []
        img_val = prompt_cfg.get("image_base64")
        if img_val:
            if isinstance(img_val, str):
                imgs.append(img_val)
            elif isinstance(img_val, list):
                imgs.extend(img_val)
        vid_frames = prompt_cfg.get("video_frames_base64")
        if vid_frames:
            if isinstance(vid_frames, list):
                imgs.extend(vid_frames)
            else:
                imgs.append(vid_frames)
        return imgs or None

    @classmethod
    def _append_video_file_frames(cls, params: Dict[str, Any]) -> None:
        """Add JPEG frames from local video files in ``params["file_paths"]`` to ``params["images"]``.

        Best effort: failures are logged, never raised. Existing images are
        kept. New frames are added only up to ``MAX_IMAGES_WITH_VIDEO_FRAMES``
        images in total.
        """
        try:
            file_paths_value = params.get("file_paths")
            candidate_paths: List[str] = []
            if isinstance(file_paths_value, str) and file_paths_value.strip():
                candidate_paths = [file_paths_value.strip()]
            elif isinstance(file_paths_value, list):
                candidate_paths = [p for p in file_paths_value if isinstance(p, str) and p.strip()]

            extracted_from_files: List[str] = []
            for p in candidate_paths:
                if _is_video_file(p) and os.path.isfile(p):
                    extracted_from_files.extend(_extract_video_file_frames_as_b64(p))

            if extracted_from_files:
                if params.get("images") is None:
                    params["images"] = []
                if not isinstance(params["images"], list):
                    params["images"] = [params["images"]]  # normalize
                remaining = max(0, cls.MAX_IMAGES_WITH_VIDEO_FRAMES - len(params["images"]))
                params["images"].extend(extracted_from_files[:remaining])
                logger.info("Added %d frame(s) extracted from video file paths to images payload.", min(len(extracted_from_files), remaining))
        except Exception as e:
            logger.warning("Failed to process video file paths for frame extraction: %s", e, exc_info=True)

    def generate(self, prompt, hide_thinking, context=None, temperature=0.7, max_tokens=1024, top_p=0.9, top_k=40, seed=-1, llm_model=None):
        """Run a single non-streaming request and return the ComfyUI result dict.

        Returns ``{"ui": {"string": [text]}, "result": (payload, str(payload))}``.
        ``payload`` is a ``ContextPayload`` wrapping the text and a context
        dict with ``llm_response`` / ``llm_raw_response``. Any failure is
        reported in the same shape, with an error message as the text.
        """
        try:
            # Base parameter defaults (overridden by provider / prompt config below).
            params = {
                "llm_provider": self.DEFAULT_PROVIDER,
                "llm_model": llm_model or self.DEFAULT_MODEL,
                "system_message": "You are a helpful, creative, and concise assistant.",
                "user_message": prompt,
                "base_ip": "localhost",
                "port": "11434",
                "temperature": temperature,
                "max_tokens": max_tokens,
                "top_p": top_p,
                "top_k": top_k,
                "seed": seed if seed != -1 else None,  # -1 means "no fixed seed"
                "repeat_penalty": 1.1,
                "stop": None,
                "keep_alive": True,
                "messages": [],
            }

            # The provider config is either the context itself or nested inside it.
            provider_config = None
            if isinstance(context, dict):
                if "provider_name" in context:
                    provider_config = context
                elif "provider_config" in context:
                    provider_config = context["provider_config"]

            if provider_config and isinstance(provider_config, dict):
                if "provider_name" in provider_config:
                    params["llm_provider"] = provider_config["provider_name"]
                if "api_key" in provider_config:
                    params["llm_api_key"] = provider_config["api_key"]
                if "base_ip" in provider_config:
                    params["base_ip"] = provider_config["base_ip"]
                if "port" in provider_config:
                    params["port"] = provider_config["port"]
                # Every other key passes straight through as a request parameter.
                for key in provider_config:
                    if key not in ["provider_name", "llm_model", "api_key", "base_ip", "port", "user_prompt"]:
                        params[key] = provider_config[key]
                if "user_prompt" in provider_config:
                    params["user_message"] = provider_config["user_prompt"]
                # An empty model is resolved to a per-provider default below.
                params["llm_model"] = provider_config.get("llm_model", "") or ""

            if params.get("llm_provider"):
                provider = str(params["llm_provider"]).lower()
                if not params.get("llm_model"):
                    fallback = self.PROVIDER_DEFAULT_MODELS.get(provider)
                    if fallback:
                        params["llm_model"] = fallback

                # Groq can hide reasoning server-side, avoiding post-processing.
                if provider == "groq":
                    params["reasoning_format"] = "hidden" if hide_thinking else "raw"

                # Auto-fetch the OpenAI API key when it is missing or a placeholder.
                if provider == "openai" and (not params.get("llm_api_key") or params["llm_api_key"] in {"", "1234", None}):
                    try:
                        params["llm_api_key"] = get_api_key("OPENAI_API_KEY", "openai")
                        logger.info("generate: Retrieved OpenAI API key via get_api_key helper.")
                    except ValueError as _e:
                        logger.warning(f"generate: get_api_key failed – {_e}")

            # --- Prompt-Manager support ---
            prompt_cfg = context.get("prompt_config") if isinstance(context, dict) else None
            if prompt_cfg and isinstance(prompt_cfg, dict):
                if prompt_cfg.get("text"):
                    params["user_message"] = prompt_cfg["text"]
                params["images"] = self._collect_prompt_images(prompt_cfg)
                # File paths / URLs are passed through for APIs that handle them directly.
                if prompt_cfg.get("file_paths"):
                    params["file_paths"] = prompt_cfg["file_paths"]
                    logger.debug("PromptManager: Passing file_paths to API (some APIs process videos/PDFs directly).")
                if prompt_cfg.get("urls"):
                    params["urls"] = prompt_cfg["urls"]
                    logger.debug("PromptManager: Passing URLs to API.")
                if prompt_cfg.get("audio_path"):
                    params["audio_path"] = prompt_cfg["audio_path"]
            else:
                params["images"] = None

            # Local video files among file_paths contribute a few frames as images.
            self._append_video_file_frames(params)

            log_params = _sanitize_params_for_log(params)
            logger.info(f"[Non-Streaming] Making LLM request with params: {log_params}")

            try:
                response_data = run_async(
                    send_request(
                        llm_provider=params["llm_provider"],
                        base_ip=params.get("base_ip", "localhost"),
                        port=params.get("port", "11434"),
                        images=params.get("images"),
                        llm_model=params["llm_model"],
                        system_message=params["system_message"],
                        user_message=params["user_message"],
                        messages=params["messages"],
                        seed=params.get("seed"),
                        temperature=params["temperature"],
                        max_tokens=params["max_tokens"],
                        random=params.get("random", False),
                        top_k=params["top_k"],
                        top_p=params["top_p"],
                        repeat_penalty=params["repeat_penalty"],
                        stop=params.get("stop"),
                        keep_alive=params.get("keep_alive", True),
                        llm_api_key=params.get("llm_api_key"),
                        reasoning_format=params.get("reasoning_format"),
                    )
                )
            except Exception as e:
                logger.error(f"Error in non-streaming send_request call: {e}", exc_info=True)
                response_data = {"choices": [{"message": {"content": f"Error calling send_request: {str(e)}"}}]}

            content = self._extract_content(response_data)

            # Apply thinking tag removal if requested.
            if hide_thinking and content:
                content = _remove_thinking_tags(content)

            if isinstance(context, dict):
                context_out = context.copy()
                context_out["llm_response"] = content
                context_out["llm_raw_response"] = response_data
            else:
                context_out = {"llm_response": content, "llm_raw_response": response_data, "passthrough_data": context}

            payload = ContextPayload(content, context_out)
            return {"ui": {"string": [content]}, "result": (payload, str(payload))}

        except Exception as e:
            error_message = f"Error generating text: {str(e)}"
            logger.error(error_message, exc_info=True)
            error_output = {"error": error_message, "original_input": context}
            payload = ContextPayload(error_message, error_output)
            return {"ui": {"string": [error_message]}, "result": (payload, str(payload))}
|
||
|
||
# --- Helper to avoid dumping large base64 blobs to INFO log -----------------
|
||
def _sanitize_params_for_log(d: dict) -> dict:
|
||
"""Return a shallow copy with huge fields replaced by short placeholders."""
|
||
out = {}
|
||
for k, v in d.items():
|
||
# Hide large binary blobs (images, video frames)
|
||
if k in {"images", "video_frames", "base64_images"} and v:
|
||
# Replace with concise placeholder
|
||
if isinstance(v, (list, tuple)):
|
||
out[k] = f"[{len(v)} item(s) base64 omitted]"
|
||
elif isinstance(v, str):
|
||
out[k] = "[base64 string omitted]"
|
||
else:
|
||
out[k] = "[data omitted]"
|
||
continue
|
||
|
||
# Mask any values that look like API keys so we never log them in full
|
||
key_lc = k.lower()
|
||
if "api_key" in key_lc or key_lc.endswith("_key") or key_lc.endswith("key"):
|
||
if isinstance(v, str) and v:
|
||
# Keep first 5 chars for debugging, mask the rest
|
||
masked = v[:5] + "…" if len(v) > 5 else "***"
|
||
out[k] = masked
|
||
else:
|
||
out[k] = "[key hidden]"
|
||
continue
|
||
|
||
# Default passthrough
|
||
else:
|
||
out[k] = v
|
||
return out
|
||
|
||
# --- NEW STREAMING NODE ---
|
||
class LLMToolkitTextGeneratorStream:
    """ComfyUI node that generates text and streams it to the UI over websocket.

    Request parameters are resolved in layers: node defaults, then matching
    keys of the incoming ``context`` dict, then its ``provider_config`` (or the
    flat context itself when it carries ``provider_name``), then the node's own
    ``prompt`` input, and finally an optional PromptManager ``prompt_config``.
    Streamed text is sent as ``llmtoolkit.stream.chunk`` events between a
    ``llmtoolkit.stream.start`` and a ``llmtoolkit.stream.end`` event, or
    ``llmtoolkit.stream.error`` on failure.
    """

    DEFAULT_PROVIDER = "openai"

    DEFAULT_MODEL: str = "gpt-4o-mini"

    MODEL_LIST: List[str] = [DEFAULT_MODEL]

    # Delimiters of a model's "thinking" section, hidden from the streamed UI
    # text when hide_thinking is enabled: standard <think> tags and the
    # Kimi-VL ◁think▷ variants. Lengths are taken with len(), never hard-coded.
    _THINK_OPEN_TAGS: Tuple[str, ...] = ("<think>", "◁think▷")
    _THINK_CLOSE_TAGS: Tuple[str, ...] = ("</think>", "◁/think▷")

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "prompt": ("STRING", {"multiline": False, "default": "Write a detailed description of a futuristic city."}),
                "hide_thinking": ("BOOLEAN", {"default": True, "tooltip": "Hide model thinking process (content between <think> tags)"})
            },
            "optional": {
                "context": ("*", {}),
                "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.05}),
                "max_tokens": ("INT", {"default": 1024, "min": 1, "max": 65536}),
                "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.05}),
                "top_k": ("INT", {"default": 40, "min": 0, "max": 500}),
                "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647, "tooltip": "-1 = random seed"})
            },
            "hidden": {
                "unique_id": "UNIQUE_ID",
                "llm_model": ("STRING", {"default": cls.DEFAULT_MODEL})
            },
        }

    RETURN_TYPES = ("*", "STRING")
    RETURN_NAMES = ("context", "text")
    FUNCTION = "generate_stream"
    CATEGORY = "🔗llm_toolkit/generators"
    OUTPUT_NODE = True  # Keep the JS widget logic

    @staticmethod
    def _partial_tag_suffix_len(text: str, tags: Tuple[str, ...]) -> int:
        """Return the length of the longest suffix of *text* that is a proper
        prefix of one of *tags* (0 when there is none)."""
        best = 0
        for tag in tags:
            for size in range(min(len(tag) - 1, len(text)), best, -1):
                if text.endswith(tag[:size]):
                    best = size
                    break
        return best

    @classmethod
    def _filter_thinking_stream(
        cls, pending: str, inside_thinking: bool, chunk: str
    ) -> Tuple[str, str, bool]:
        """Incrementally strip thinking sections from a streamed chunk.

        Args:
            pending: Text held back from the previous call because it might be
                the start of a tag split across chunk boundaries.
            inside_thinking: Whether the stream is currently inside a thinking
                section.
            chunk: The newly received text.

        Returns:
            ``(visible, pending, inside_thinking)``: the text safe to show now,
            the text to carry into the next call, and the updated state.
            Thinking content and the tags themselves are never returned as
            visible text.
        """
        text = pending + chunk
        visible_parts = []
        while text:
            tags = cls._THINK_CLOSE_TAGS if inside_thinking else cls._THINK_OPEN_TAGS
            # Earliest complete tag of the kind we are looking for.
            hit = None
            for tag in tags:
                idx = text.find(tag)
                if idx != -1 and (hit is None or idx < hit[0]):
                    hit = (idx, tag)
            if hit is not None:
                idx, tag = hit
                if not inside_thinking:
                    visible_parts.append(text[:idx])
                text = text[idx + len(tag):]
                inside_thinking = not inside_thinking
                continue
            # No complete tag: hold back a suffix that could begin one so a
            # tag split across chunks is still recognised next time.
            keep = cls._partial_tag_suffix_len(text, tags)
            cut = len(text) - keep
            if not inside_thinking:
                visible_parts.append(text[:cut])
            text = text[cut:]
            break
        return "".join(visible_parts), text, inside_thinking

    def generate_stream(self, prompt, hide_thinking, unique_id, llm_model=None, context=None, **kwargs):
        """Generate text and stream it to the UI (synchronous ComfyUI entry point).

        The async work runs inside ``_async_generate`` and is driven to
        completion with ``run_async``.

        Returns:
            A ComfyUI result dict ``{"ui": {"string": [text]}, "result":
            (payload, str(payload))}`` where ``payload`` is a ContextPayload
            wrapping the final (or error) text and the output context.
        """

        async def _async_generate():
            if PromptServer is None:
                logger.error("PromptServer not available. Cannot stream.")
                error_msg = "Streaming requires PromptServer, which is not available."
                error_output = {"error": error_msg, "original_input": context}
                payload = ContextPayload(error_msg, error_output)
                return {"ui": {"string": [error_msg]}, "result": (payload, str(payload))}

            server = PromptServer.instance
            full_response_text = ""
            # State for the incremental thinking-tag filter.
            pending_text = ""
            inside_thinking = False

            try:
                # 1. Start with node defaults
                params = {
                    "llm_provider": self.DEFAULT_PROVIDER,
                    "llm_model": llm_model or self.DEFAULT_MODEL,
                    "system_message": "You are a helpful, creative, and concise assistant.",
                    "user_message": prompt,
                    "base_ip": "localhost", "port": "11434",
                    "temperature": kwargs.get('temperature', 0.7), "max_tokens": kwargs.get('max_tokens', 1024), "top_p": kwargs.get('top_p', 0.9), "top_k": kwargs.get('top_k', 40), "seed": kwargs.get('seed', -1),
                    "repeat_penalty": 1.1, "stop": None, "keep_alive": "5m",
                    "messages": [], "llm_api_key": None,
                }

                # 2. Apply general settings from the incoming context
                if context and isinstance(context, dict):
                    for key, value in context.items():
                        if key in params and value is not None:
                            params[key] = value

                # 3. Apply specific provider_config settings, which take precedence
                provider_config = None
                if context and isinstance(context, dict):
                    if "provider_config" in context and isinstance(context["provider_config"], dict):
                        provider_config = context["provider_config"]
                    elif "provider_name" in context:  # Handle flat context from old nodes
                        provider_config = context

                if provider_config:
                    # Update keys shared by provider_config and params.
                    params.update({k: v for k, v in provider_config.items() if k in params and v is not None})

                    # Map keys whose names differ between config and params.
                    if "provider_name" in provider_config and provider_config["provider_name"]:
                        params["llm_provider"] = provider_config["provider_name"]
                    elif context and isinstance(context, dict) and context.get("provider_name"):
                        # Fallback to root-level key if nested provider_config didn't have it
                        params["llm_provider"] = context["provider_name"]

                    if "api_key" in provider_config and provider_config["api_key"]:
                        params["llm_api_key"] = provider_config["api_key"]
                    elif context and isinstance(context, dict) and context.get("api_key"):
                        params["llm_api_key"] = context["api_key"]

                    if "llm_model" in provider_config and provider_config["llm_model"]:
                        params["llm_model"] = provider_config["llm_model"]
                    elif context and isinstance(context, dict) and context.get("llm_model"):
                        params["llm_model"] = context["llm_model"]

                # 4. A user_prompt from context/provider_config may be used, but a
                # non-empty node prompt input overrides it.
                if context and isinstance(context, dict) and "user_prompt" in context and context["user_prompt"]:
                    params["user_message"] = context.get("user_prompt")
                if provider_config and "user_prompt" in provider_config and provider_config["user_prompt"]:
                    params["user_message"] = provider_config.get("user_prompt")
                if prompt:  # Node input is final override, if provided
                    params["user_message"] = prompt

                # Finalize model name fallback
                provider = "unknown"
                if params.get("llm_provider"):
                    provider = str(params["llm_provider"]).lower()
                if not params.get("llm_model"):
                    PROVIDER_DEFAULTS = {"openai": "gpt-4o-mini", "anthropic": "claude-3-opus-20240229"}
                    fallback = PROVIDER_DEFAULTS.get(provider)
                    params["llm_model"] = fallback or self.DEFAULT_MODEL

                # Groq reasoning format
                if provider == "groq":
                    params["reasoning_format"] = "hidden" if hide_thinking else "raw"

                # Auto-fetch API key for OpenAI if missing/placeholder
                if provider == "openai" and (not params.get("llm_api_key") or params["llm_api_key"] in {"", "1234", None}):
                    try:
                        params["llm_api_key"] = get_api_key("OPENAI_API_KEY", "openai")
                        logger.info("generate_stream: Retrieved OpenAI API key via get_api_key helper.")
                    except ValueError as _e:
                        logger.warning(f"generate_stream: get_api_key failed – {_e}")

                # --- Prompt-Manager support: use an attached prompt_config dict ---
                prompt_cfg = None
                if context is not None and isinstance(context, dict):
                    prompt_cfg = context.get("prompt_config")

                if prompt_cfg and isinstance(prompt_cfg, dict):
                    # Override user prompt text if provided
                    if prompt_cfg.get("text"):
                        params["user_message"] = prompt_cfg["text"]

                    # Collect images and video frames (already base64 from PromptManager)
                    imgs = []
                    if prompt_cfg.get("image_base64"):
                        img_val = prompt_cfg["image_base64"]
                        if isinstance(img_val, str):
                            imgs.extend([img_val])
                        elif isinstance(img_val, list):
                            imgs.extend(img_val)

                    if prompt_cfg.get("video_frames_base64"):
                        vid_frames = prompt_cfg["video_frames_base64"]
                        if isinstance(vid_frames, list):
                            imgs.extend(vid_frames)
                        else:
                            imgs.append(vid_frames)

                    params["images"] = imgs if imgs else None

                    # Pass through file paths and URLs for APIs that support them
                    if prompt_cfg.get("file_paths"):
                        params["file_paths"] = prompt_cfg["file_paths"]
                        logger.debug(f"PromptManager: Passing file_paths to API (some APIs process videos/PDFs directly).")

                    if prompt_cfg.get("urls"):
                        params["urls"] = prompt_cfg["urls"]
                        logger.debug(f"PromptManager: Passing URLs to API.")

                    # Audio path remains separate
                    if prompt_cfg.get("audio_path"):
                        params["audio_path"] = prompt_cfg["audio_path"]
                else:
                    params["images"] = None

                # If file_paths contain local video files, extract a few frames
                try:
                    file_paths_value = params.get("file_paths")
                    candidate_paths: List[str] = []
                    if isinstance(file_paths_value, str) and file_paths_value.strip():
                        candidate_paths = [file_paths_value.strip()]
                    elif isinstance(file_paths_value, list):
                        candidate_paths = [p for p in file_paths_value if isinstance(p, str) and p.strip()]

                    extracted_from_files: List[str] = []
                    for p in candidate_paths:
                        if _is_video_file(p) and os.path.isfile(p):
                            extracted_from_files.extend(_extract_video_file_frames_as_b64(p))

                    if extracted_from_files:
                        if params.get("images") is None:
                            params["images"] = []
                        if not isinstance(params["images"], list):
                            params["images"] = [params["images"]]  # normalize
                        # cap total images to 8 to avoid huge payloads
                        remaining = max(0, 8 - len(params["images"]))
                        params["images"].extend(extracted_from_files[:remaining])
                        logger.info(
                            "[Streaming] Added %d frame(s) extracted from video file paths to images payload.",
                            min(len(extracted_from_files), remaining),
                        )
                except Exception as e:
                    logger.warning(
                        "[Streaming] Failed to process video file paths for frame extraction: %s",
                        e,
                        exc_info=True,
                    )

                log_params = _sanitize_params_for_log(params)
                logger.info(
                    f"[Streaming] Initiating LLM stream with params: {log_params} for node {unique_id}"
                )

                # --- Send START message ---
                server.send_sync("llmtoolkit.stream.start", {"node": unique_id}, sid=server.client_id)

                # --- Initiate and process the stream ---
                stream_generator = send_request_stream(
                    llm_provider=params["llm_provider"],
                    base_ip=params.get("base_ip", "localhost"),
                    port=params.get("port", "11434"),
                    llm_model=params["llm_model"],
                    system_message=params["system_message"],
                    user_message=params["user_message"],
                    messages=params["messages"],
                    seed=params.get("seed"),
                    temperature=params["temperature"],
                    max_tokens=params["max_tokens"],
                    random=params.get("random", False),
                    top_k=params["top_k"],
                    top_p=params["top_p"],
                    repeat_penalty=params["repeat_penalty"],
                    stop=params.get("stop"),
                    keep_alive=params.get("keep_alive", True),
                    llm_api_key=params.get("llm_api_key"),
                    base64_images=params.get("images"),
                )

                async for chunk in stream_generator:
                    if chunk:
                        full_response_text += chunk
                        if hide_thinking:
                            # Strip thinking sections incrementally; a tag split
                            # across chunk boundaries is carried in pending_text.
                            chunk_to_send, pending_text, inside_thinking = self._filter_thinking_stream(
                                pending_text, inside_thinking, chunk
                            )
                        else:
                            chunk_to_send = chunk
                        if chunk_to_send:
                            server.send_sync(
                                "llmtoolkit.stream.chunk",
                                {"node": unique_id, "text": chunk_to_send},
                                sid=server.client_id,
                            )
                    await asyncio.sleep(0.001)

                # Release visible text that was held back only because it looked
                # like the start of a tag (e.g. a trailing "<").
                if hide_thinking and not inside_thinking and pending_text:
                    server.send_sync(
                        "llmtoolkit.stream.chunk",
                        {"node": unique_id, "text": pending_text},
                        sid=server.client_id,
                    )

                # Apply thinking tag removal to final text if requested
                final_text = full_response_text
                if hide_thinking:
                    final_text = _remove_thinking_tags(full_response_text)

                logger.info(
                    f"[Streaming] Finished for node {unique_id}. Total length: {len(final_text)}"
                )
                server.send_sync(
                    "llmtoolkit.stream.end",
                    {"node": unique_id, "final_text": final_text},
                    sid=server.client_id,
                )

                # --- Prepare final context output ---
                if context is not None and isinstance(context, dict):
                    context_out = context.copy()
                    context_out["llm_response"] = final_text
                    context_out["llm_raw_response"] = {
                        "status": "Streamed successfully",
                        "final_length": len(final_text),
                    }
                else:
                    context_out = {
                        "llm_response": final_text,
                        "llm_raw_response": {
                            "status": "Streamed successfully",
                            "final_length": len(final_text),
                        },
                        "passthrough_data": context,
                    }

                payload = ContextPayload(final_text, context_out)
                return {"ui": {"string": [final_text]}, "result": (payload, str(payload))}

            except Exception as e:
                error_message = f"Error during streaming generation: {str(e)}"
                logger.error(error_message, exc_info=True)
                if server and unique_id:
                    server.send_sync(
                        "llmtoolkit.stream.error",
                        {"node": unique_id, "error": error_message},
                        sid=server.client_id,
                    )
                error_output = {
                    "error": error_message,
                    "partial_response": full_response_text,
                    "original_input": context,
                }
                payload = ContextPayload(error_message, error_output)
                return {
                    "ui": {"string": [f"Error: {error_message}\nPartial: {full_response_text}"]},
                    "result": (payload, str(payload)),
                }

        # Execute the inner coroutine and return its result synchronously
        return run_async(_async_generate())
|
||
|
||
|
||
# Node registration tables for ComfyUI: node id -> implementing class, and
# node id -> display name.
NODE_CLASS_MAPPINGS = dict(
    LLMToolkitTextGenerator=LLMToolkitTextGenerator,  # original non-streaming node
    LLMToolkitTextGeneratorStream=LLMToolkitTextGeneratorStream,  # streaming variant
)

NODE_DISPLAY_NAME_MAPPINGS = dict(
    LLMToolkitTextGenerator="Generate Text (🔗LLMToolkit)",
    LLMToolkitTextGeneratorStream="Generate Text Stream (🔗LLMToolkit)",
)