# generate_text.py
import os
import sys
import json
import base64
import asyncio
import logging
from typing import Dict, Any, List, Optional, Union, Tuple, AsyncGenerator

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Add parent directory to path to ensure imports work correctly
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)

# Try to import ComfyUI-specific modules
try:
    import folder_paths
    from server import PromptServer  # <-- Import PromptServer
except ImportError:
    logger.warning("Could not import folder_paths or server. Make sure ComfyUI environment is set up.")
    folder_paths = None
    PromptServer = None

# Check for required dependencies
missing_deps = []
try:
    import aiohttp
except ImportError:
    missing_deps.append("aiohttp")

if missing_deps:
    logger.warning(f"Missing dependencies: {', '.join(missing_deps)}. Some functionality may not work.")
    logger.warning("Please install missing dependencies: pip install " + " ".join(missing_deps))

# Import utility functions (assuming they exist)
from send_request import send_request, run_async, is_gpt5_model  # Keep non-streaming version if needed
from api.openai_api import send_openai_responses_stream
from api.gemini_api import send_gemini_request_stream
from llmtoolkit_utils import query_local_ollama_models, ensure_ollama_server, ensure_ollama_model, get_api_key

# Local transformers streaming (optional) - removed unused import
# The send_transformers_request_stream function was imported but never used
send_transformers_request_stream = None  # type: ignore

# Payload helper to embed context into a string subclass
from context_payload import ContextPayload

# -----------------------------------------------------------------------------
# Helpers: Lazy OpenCV import and video -> base64 frame extraction
# -----------------------------------------------------------------------------
_cv2_ref = None


def _get_cv2():
    """Return the ``cv2`` module, importing it on first use.

    The module object is cached in the module-level ``_cv2_ref``. If the
    import raises for any reason, a warning is logged and ``None`` is
    returned; the failure is not cached, so a later call tries again.
    """
    global _cv2_ref
    if _cv2_ref is None:
        try:
            import cv2 as _cv2  # type: ignore
            _cv2_ref = _cv2
        except Exception:
            # Broad on purpose: a broken OpenCV install can raise more than
            # ImportError, and video support is strictly optional here.
            _cv2_ref = None
            logger.warning("cv2 not available. Video file frame extraction disabled.")
    return _cv2_ref


def _is_video_file(path: str) -> bool:
    """Return True if *path* ends with a known video extension (case-insensitive).

    Only the extension is inspected; the file is never opened. Any error while
    splitting the path (e.g. a non-string argument) yields False.
    """
    try:
        ext = os.path.splitext(path)[1].lower()
    except Exception:
        return False
    return ext in {".mp4", ".mov", ".mkv", ".avi", ".webm", ".m4v"}


def _extract_video_file_frames_as_b64(
    video_path: str,
    max_frames: int = 5,
    stride: int = 16,
) -> List[str]:
    """Extract up to `max_frames` JPEG base64 frames from a local video file.

    - Uses a simple stride to subsample frames (frame 0, stride, 2*stride, ...).
    - Falls back to the first N sequential frames if stride sampling produced
      nothing.
    - Returns an empty list on any error or if cv2 is unavailable.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        max_frames: Upper bound on the number of frames returned.
        stride: Distance, in frames, between sampled positions. Values below 1
            are treated as 1 so the same frame is not sampled repeatedly.

    Returns:
        Base64-encoded (ASCII) JPEG images, at most ``max_frames`` of them.
    """
    cv2 = _get_cv2()
    if cv2 is None:
        return []

    # A non-positive stride would keep seeking to the same (or a negative)
    # position and return duplicates of a single frame.
    stride = max(1, stride)

    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            logger.warning("Could not open video file: %s", video_path)
            return []

        frames: List[str] = []
        frame_index = 0

        # Try stride sampling first
        while len(frames) < max_frames:
            cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index)
            ok, frame = cap.read()
            if not ok:
                break
            success, buf = cv2.imencode(".jpg", frame)
            if success:
                frames.append(base64.b64encode(buf.tobytes()).decode("ascii"))
            frame_index += stride

        # If we got nothing with stride, try the first sequential frames
        if not frames:
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            while len(frames) < max_frames:
                ok, frame = cap.read()
                if not ok:
                    break
                success, buf = cv2.imencode(".jpg", frame)
                if success:
                    frames.append(base64.b64encode(buf.tobytes()).decode("ascii"))

        if frames:
            logger.debug("Extracted %d frame(s) from video file %s", len(frames), video_path)
        return frames
    except Exception as e:
        logger.warning("Error extracting frames from %s: %s", video_path, e, exc_info=True)
        return []
    finally:
        # Release the capture on every path (unopened file, exception, success)
        # so the underlying file handle / decoder is never leaked.
        if cap is not None:
            cap.release()

# --- NEW STREAMING REQUEST FUNCTION (Example for Ollama) ---
# IMPORTANT: This needs to be adapted based on the actual API structure of the provider!
+async def send_request_stream( + llm_provider: str, + base_ip: str, + port: str, + llm_model: str, + system_message: str, + user_message: str, + messages: List[Dict[str, str]], + seed: Optional[int] = None, + temperature: float = 0.7, + max_tokens: int = 1024, + random: bool = False, + top_k: int = 40, + top_p: float = 0.9, + repeat_penalty: float = 1.1, + stop: Optional[List[str]] = None, + keep_alive: Union[bool, str] = True, + llm_api_key: Optional[str] = None, + timeout: int = 120, # Add a timeout for the connection + base64_images: Optional[List[str]] = None, +) -> AsyncGenerator[str, None]: + """ + Sends a streaming request to an LLM provider (Example for Ollama). + Yields text chunks as they are received. + """ + provider_lower = llm_provider.lower() + + if provider_lower in ["openai", "openrouter"]: + # --- OpenAI & OpenRouter Specific Streaming Logic --- + if not llm_api_key: + logger.error(f"{llm_provider} streaming requested but no API key supplied.") + yield f"[{llm_provider} Error: API key missing]" + return + + # Check if this is a GPT-5 model and use Responses API (OpenAI only) + if provider_lower == "openai" and is_gpt5_model(llm_model): + logger.info(f"Detected GPT-5 model: {llm_model}, using Responses API for streaming") + try: + async for chunk in send_openai_responses_stream( + api_url="https://api.openai.com/v1/responses", + base64_images=base64_images, + model=llm_model, + system_message=system_message, + user_message=user_message, + messages=messages, + api_key=llm_api_key, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + ): + yield chunk + return + except Exception as e: + logger.warning(f"GPT-5 Responses stream failed: {e}, falling back to Chat Completions") + # Fall through to regular OpenAI streaming below + + api_url = "https://openrouter.ai/api/v1/chat/completions" if provider_lower == "openrouter" else "https://api.openai.com/v1/chat/completions" + + headers = { + "Authorization": f"Bearer {llm_api_key}", + 
"Content-Type": "application/json", + } + # Build message list if not provided + if not messages: + messages = [] + if system_message: + messages.append({"role": "system", "content": system_message}) + + # Handle multimodal user content (text + images) + if base64_images: + content_blocks = [] + if user_message: + content_blocks.append({"type": "text", "text": user_message}) + for img_b64 in base64_images: + content_blocks.append({ + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}, + }) + messages.append({"role": "user", "content": content_blocks}) + else: + if user_message: + messages.append({"role": "user", "content": user_message}) + + payload = { + "model": llm_model, + "messages": messages, + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "stream": True, + } + # Remove None values + payload = {k: v for k, v in payload.items() if v is not None} + + logger.info(f"Streaming request to {llm_provider}: model={llm_model}") + session = None + try: + # Create session with custom connector for Windows + connector = aiohttp.TCPConnector(force_close=True) if sys.platform == 'win32' else None + session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout), + connector=connector + ) + + async with session.post(api_url, headers=headers, json=payload) as response: + response.raise_for_status() + async for raw_line in response.content: + if not raw_line: + continue + line = raw_line.decode("utf-8").strip() + if not line: + continue + # OpenAI streams multiple lines that may begin with 'data:'; join those if needed. 
+ if line.startswith("data: "): + data_str = line[len("data: ") :].strip() + else: + data_str = line + if data_str == "[DONE]": + break + try: + data_json = json.loads(data_str) + choices = data_json.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + content_piece = delta.get("content") + if content_piece: + yield content_piece + except json.JSONDecodeError: + logger.warning(f"Could not decode JSON line from OpenAI stream: {data_str}") + except Exception as e: + logger.error(f"Error during OpenAI streaming: {e}", exc_info=True) + yield f"[OpenAI streaming error: {e}]" + finally: + # Ensure session is properly closed + if session and not session.closed: + await session.close() + # Small delay to allow cleanup on Windows + if sys.platform == 'win32': + await asyncio.sleep(0.1) + return + + if provider_lower == "gemini": + if not llm_api_key: + logger.error("Gemini streaming requested but no API key supplied.") + yield "[Gemini Error: API key missing]" + return + + async for chunk in send_gemini_request_stream( + api_key=llm_api_key, + model=llm_model, + system_message=system_message, + user_message=user_message, + messages=messages, + base64_images=base64_images, + temperature=temperature, + max_tokens=max_tokens, + top_p=top_p, + top_k=top_k, + ): + yield chunk + return + + if provider_lower == "ollama": + # --- Ollama Specific Streaming Logic --- + if not ensure_ollama_server(base_ip, port): + logger.error("Ollama daemon unavailable and could not be started – aborting stream.") + yield "[Error: Ollama daemon unavailable]" + return + + # Ensure requested model is present locally (will pull if missing) + ensure_ollama_model(llm_model, base_ip, port) + + # Ollama expects plain base64 strings without the 'data:image/...' 
prefix + if base64_images: + base64_images = [ + img.split("base64,")[1] if isinstance(img, str) and "base64," in img else img + for img in base64_images + ] + + # Decide which Ollama endpoint to use: + # • /api/generate – fast text-only streaming (no image support) + # • /api/chat – full chat/completions with image support + # We switch to /api/chat automatically if the caller supplied base64-encoded images + # so that vision models (e.g. qwen-vl, llava) receive the image bytes. + + use_chat_endpoint = bool(base64_images) # True if we have images to send + + url = ( + f"http://{base_ip}:{port}/api/chat" if use_chat_endpoint else f"http://{base_ip}:{port}/api/generate" + ) + + headers = {"Content-Type": "application/json"} + if llm_api_key: # Ollama doesn't typically use API keys this way, but include for consistency + headers["Authorization"] = f"Bearer {llm_api_key}" + + # ------------------------------------------------------------------ + # Build request payloads + # ------------------------------------------------------------------ + + if use_chat_endpoint: + # -------------------------------------------------------------- + # /api/chat (supports multimodal, messages array) + # -------------------------------------------------------------- + if not messages: + messages = [] + if system_message: + messages.append({"role": "system", "content": system_message}) + + # Always append the user message at the end so that vision models + # get the most recent prompt + images in a single message. 
+ user_msg: dict[str, Any] = {"role": "user", "content": user_message or ""} + if base64_images: + user_msg["images"] = base64_images # Ollama expects a list key called "images" + messages.append(user_msg) + + payload = { + "model": llm_model, + "messages": messages, + "stream": True, + "options": { + "seed": seed, + "temperature": temperature, + "num_predict": max_tokens, + "top_k": top_k, + "top_p": top_p, + "repeat_penalty": repeat_penalty, + "stop": stop, + }, + } + else: + # -------------------------------------------------------------- + # /api/generate (text-only) + # -------------------------------------------------------------- + # Construct messages list if not provided directly (for context) + if not messages: + messages = [] + if system_message: + messages.append({"role": "system", "content": system_message}) + if user_message: + messages.append({"role": "user", "content": user_message}) + + payload = { + "model": llm_model, + "prompt": user_message, + "system": system_message if system_message else None, + "stream": True, + "options": { + "seed": seed, + "temperature": temperature, + "num_predict": max_tokens, + "top_k": top_k, + "top_p": top_p, + "repeat_penalty": repeat_penalty, + "stop": stop, + }, + } + # Clean up None values Ollama might not like + if not system_message: + del payload["system"] + + # Remove None values from options dict + if "options" in payload: + payload["options"] = {k: v for k, v in payload["options"].items() if v is not None} + + logger.info(f"Streaming request to Ollama: {url} with payload: {{'model': '{llm_model}', 'stream': True, 'use_chat': {use_chat_endpoint}}}") + + try: + # Use a single session if possible, manage timeouts + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + async with session.post(url, headers=headers, json=payload) as response: + response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) + + # Process the streaming response line by 
line. The JSON schema differs slightly + # between /api/generate and /api/chat, so we branch inside the loop. + async for line in response.content: + if not line: + continue + try: + decoded_line = line.decode("utf-8").strip() + if not decoded_line: + continue + + data = json.loads(decoded_line) + + if use_chat_endpoint: + # Chat endpoint returns {"message": {"content": "..."}, "done": bool} + msg = data.get("message", {}) + chunk = msg.get("content", "") + is_done = data.get("done", False) + else: + # Generate endpoint returns {"response": "...", "done": bool} + chunk = data.get("response", "") + is_done = data.get("done", False) + + if chunk: + yield chunk + + if is_done: + logger.info("Ollama stream finished.") + break + except json.JSONDecodeError: + logger.warning( + f"Could not decode JSON line from Ollama stream: {line.decode('utf-8', errors='ignore')}" + ) + except Exception as e: + logger.error(f"Error processing Ollama stream line: {e}", exc_info=True) + yield f"[Error processing stream: {e}]" + break # Stop streaming on error + + except aiohttp.ClientConnectorError as e: + logger.error(f"Connection error to {url}: {e}") + yield f"[Connection Error: {e}]" + except aiohttp.ClientResponseError as e: + logger.error(f"HTTP error {e.status} from {url}: {e.message}") + # Attempt to read error details from response body + try: + error_body = await e.response.text() if hasattr(e, 'response') else 'No details' + logger.error(f"Error Body: {error_body}") + yield f"[HTTP Error {e.status}: {e.message} - {error_body[:100]}]" + except: + yield f"[HTTP Error {e.status}: {e.message}]" + except asyncio.TimeoutError: + logger.error(f"Request timed out after {timeout} seconds to {url}") + yield f"[Timeout Error]" + except ConnectionResetError as e: + logger.warning(f"Connection reset by peer: {e}") + yield f"[Connection Reset: The server closed the connection]" + except OSError as e: + if e.errno == 10054: # Windows specific: connection forcibly closed + 
logger.warning("Connection forcibly closed by remote host") + yield f"[Connection closed by server]" + else: + logger.error(f"OS Error during streaming: {e}") + yield f"[Network Error: {e}]" + except Exception as e: + logger.error(f"An unexpected error occurred during streaming request: {e}", exc_info=True) + yield f"[Unexpected Error: {e}]" + + if provider_lower == "groq": + # --- Groq Specific Streaming Logic --- + if not llm_api_key: + logger.error("Groq streaming requested but no API key supplied.") + # Fallback could be added here if desired, for now just yield error + yield "[Groq Error: API key missing]" + return + + groq_url = "https://api.groq.com/openai/v1/chat/completions" + headers = { + "Authorization": f"Bearer {llm_api_key}", + "Content-Type": "application/json", + } + + # Handle multimodal user content (text + images) + is_vision_model = "scout" in llm_model or "maverick" in llm_model + images_to_send = base64_images + if images_to_send and not is_vision_model: + logger.warning("Groq stream: Model '%s' may not support images. 
Sending without.", llm_model) + images_to_send = None + elif images_to_send and len(images_to_send) > 5: + logger.warning("Groq stream: Taking first 5 of %s images for vision model.", len(images_to_send)) + images_to_send = images_to_send[:5] + + # Build message list if not provided + if not messages: + messages = [] + if system_message: + messages.append({"role": "system", "content": system_message}) + + if images_to_send: + content_blocks = [{"type": "text", "text": user_message}] + for img_b64 in images_to_send: + content_blocks.append({ + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}, + }) + messages.append({"role": "user", "content": content_blocks}) + else: + if user_message: + messages.append({"role": "user", "content": user_message}) + + payload = { + "model": llm_model, + "messages": messages, + "temperature": temperature, + "max_completion_tokens": max_tokens, # Use correct Groq param + "top_p": top_p, + "stream": True, + } + # Remove None values + payload = {k: v for k, v in payload.items() if v is not None} + + logger.info(f"Streaming request to Groq: model={llm_model}") + session = None + try: + # Create session with custom connector for Windows + connector = aiohttp.TCPConnector(force_close=True) if sys.platform == 'win32' else None + session = aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=timeout), + connector=connector + ) + + async with session.post(groq_url, headers=headers, json=payload) as response: + response.raise_for_status() + async for raw_line in response.content: + line = raw_line.decode("utf-8").strip() + if not line: continue + if line.startswith("data: "): + data_str = line[len("data: ") :].strip() + else: + data_str = line + if data_str == "[DONE]": + break + try: + data_json = json.loads(data_str) + choices = data_json.get("choices", []) + if choices: + delta = choices[0].get("delta", {}) + content_piece = delta.get("content") + if content_piece: + yield content_piece + except 
json.JSONDecodeError: + logger.warning(f"Could not decode JSON line from Groq stream: {data_str}") + except Exception as e: + logger.error(f"Error during Groq streaming: {e}", exc_info=True) + yield f"[Groq streaming error: {e}]" + finally: + # Ensure session is properly closed + if session and not session.closed: + await session.close() + # Small delay to allow cleanup on Windows + if sys.platform == 'win32': + await asyncio.sleep(0.1) + return + + # Existing fallback logic for other providers + if provider_lower not in ["ollama", "openai", "openrouter", "transformers", "hf", "local", "groq", "gemini"]: + logger.warning(f"Streaming not implemented for provider '{llm_provider}'. Falling back to non-streaming.") + try: + full_response_data = await send_request( + llm_provider=llm_provider, base_ip=base_ip, port=port, images=base64_images, llm_model=llm_model, + system_message=system_message, user_message=user_message, messages=messages, seed=seed, + temperature=temperature, max_tokens=max_tokens, random=random, top_k=top_k, top_p=top_p, + repeat_penalty=repeat_penalty, stop=stop, keep_alive=keep_alive, llm_api_key=llm_api_key + ) + if isinstance(full_response_data, dict): + if "choices" in full_response_data and full_response_data["choices"]: + message = full_response_data["choices"][0].get("message", {}) + content = message.get("content", "") + if content: yield content + elif "response" in full_response_data: + if full_response_data["response"]: yield full_response_data["response"] + elif "candidates" in full_response_data and full_response_data.get("candidates"): + try: + content = full_response_data["candidates"][0]["content"]["parts"][0]["text"] + if content: yield content + except (KeyError, IndexError, TypeError): + logger.warning(f"Could not parse Gemini response format: {full_response_data}") + else: + logger.error(f"Unexpected non-streaming response format: {full_response_data}") + elif isinstance(full_response_data, str): + if full_response_data: yield 
full_response_data + else: + logger.error(f"Unexpected non-streaming response type: {type(full_response_data)}") + + except Exception as e: + logger.error(f"Error in fallback non-streaming request for {llm_provider}: {e}", exc_info=True) + yield f"[Error: {e}]" + return # Stop generation after yielding the fallback response + +def _remove_thinking_tags(text: str) -> str: + """Remove ... blocks from text, including the tags themselves.""" + import re + # Pattern to match ... or ◁think▷...◁/think▷ + pattern = r'.*?|◁think▷.*?◁/think▷' + cleaned = re.sub(pattern, '', text, flags=re.DOTALL) + # Clean up any extra whitespace/newlines left behind + cleaned = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned) # Replace multiple newlines with double + return cleaned.strip() + +# --- Original Node (for reference or non-streaming use) --- +class LLMToolkitTextGenerator: + DEFAULT_PROVIDER = "openai" + + DEFAULT_MODEL: str = "gpt-4o-mini" + + MODEL_LIST: List[str] = [DEFAULT_MODEL] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "prompt": ("STRING", {"multiline": False, "default": "Write a short story about a robot learning to paint."}), + "hide_thinking": ("BOOLEAN", {"default": True, "tooltip": "Hide model thinking process (content between tags)"}) + }, + "optional": { + "context": ("*", {}), + "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.05}), + "max_tokens": ("INT", {"default": 1024, "min": 1, "max": 65536}), + "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.05}), + "top_k": ("INT", {"default": 40, "min": 0, "max": 500}), + "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647, "tooltip": "-1 = random seed"}) + }, + "hidden": { + "llm_model": ("STRING", {"default": cls.DEFAULT_MODEL}) + } + } + + RETURN_TYPES = ("*", "STRING") + RETURN_NAMES = ("context", "text") + FUNCTION = "generate" + CATEGORY = "🔗llm_toolkit/generators" + OUTPUT_NODE = True # Keeps the text widget for non-streaming version + + 
def generate(self, prompt, hide_thinking, context=None, temperature=0.7, max_tokens=1024, top_p=0.9, top_k=40, seed=-1, llm_model=None): + # ... (original generate logic using run_async(send_request(...))) ... + # This function remains mostly the same as the user provided, + # calling the original non-streaming send_request. + # We'll copy the parameter processing logic from the streaming version + # for consistency, but it will call the non-streaming send_request. + try: + # Base parameter defaults + params = { + "llm_provider": self.DEFAULT_PROVIDER, + "llm_model": llm_model or self.DEFAULT_MODEL, + "system_message": "You are a helpful, creative, and concise assistant.", + "user_message": prompt, + "base_ip": "localhost", + "port": "11434", + "temperature": temperature, + "max_tokens": max_tokens, + "top_p": top_p, + "top_k": top_k, + "seed": seed if seed != -1 else None, + "repeat_penalty": 1.1, + "stop": None, + "keep_alive": True, + "messages": [], + } + + provider_config = None + if context is not None: + if isinstance(context, dict) and "provider_name" in context: + provider_config = context + elif isinstance(context, dict) and "provider_config" in context: + provider_config = context["provider_config"] + + if provider_config and isinstance(provider_config, dict): + if "provider_name" in provider_config: params["llm_provider"] = provider_config["provider_name"] + if "api_key" in provider_config: params["llm_api_key"] = provider_config["api_key"] + if "base_ip" in provider_config: params["base_ip"] = provider_config["base_ip"] + if "port" in provider_config: params["port"] = provider_config["port"] + for key in provider_config: + if key not in ["provider_name", "llm_model", "api_key", "base_ip", "port", "user_prompt"]: + params[key] = provider_config[key] + if "user_prompt" in provider_config: params["user_message"] = provider_config["user_prompt"] + provider_model = provider_config.get("llm_model", "") + if provider_model: + params["llm_model"] = 
provider_model + else: + params["llm_model"] = "" # Will be handled below + + if params.get("llm_provider"): + provider = str(params["llm_provider"]).lower() + if not params.get("llm_model"): + PROVIDER_DEFAULTS = {"openai": "gpt-4o-mini", "anthropic": "claude-3-opus-20240229"} + fallback = PROVIDER_DEFAULTS.get(provider) + if fallback: params["llm_model"] = fallback + + # --- Groq reasoning-format handling --- + if provider == "groq": + # Hide thinking via API when requested to avoid post-processing. + params["reasoning_format"] = "hidden" if hide_thinking else "raw" + + # Auto-fetch API key for OpenAI if missing/placeholder + if provider == "openai" and (not params.get("llm_api_key") or params["llm_api_key"] in {"", "1234", None}): + try: + params["llm_api_key"] = get_api_key("OPENAI_API_KEY", "openai") + logger.info("generate: Retrieved OpenAI API key via get_api_key helper.") + except ValueError as _e: + logger.warning(f"generate: get_api_key failed – {_e}") + + # --- Prompt-Manager support --------------------------------------------------- + # If upstream PromptManager attached a prompt_config dict, use the data + prompt_cfg = None + if context is not None and isinstance(context, dict): + prompt_cfg = context.get("prompt_config") + + if prompt_cfg and isinstance(prompt_cfg, dict): + # Override user prompt text if provided + if prompt_cfg.get("text"): + params["user_message"] = prompt_cfg["text"] + + # Collect images and video frames (already in base64 from PromptManager) + imgs = [] + if prompt_cfg.get("image_base64"): + img_val = prompt_cfg["image_base64"] + if isinstance(img_val, str): + imgs.extend([img_val]) + elif isinstance(img_val, list): + imgs.extend(img_val) + + # Add extracted video frames + if prompt_cfg.get("video_frames_base64"): + vid_frames = prompt_cfg["video_frames_base64"] + if isinstance(vid_frames, list): + imgs.extend(vid_frames) + else: + imgs.append(vid_frames) + + params["images"] = imgs if imgs else None + + # Pass through file paths 
and URLs for APIs that support them + if prompt_cfg.get("file_paths"): + params["file_paths"] = prompt_cfg["file_paths"] + logger.debug(f"PromptManager: Passing file_paths to API (some APIs process videos/PDFs directly).") + + if prompt_cfg.get("urls"): + params["urls"] = prompt_cfg["urls"] + logger.debug(f"PromptManager: Passing URLs to API.") + + # Audio path remains separate + if prompt_cfg.get("audio_path"): + params["audio_path"] = prompt_cfg["audio_path"] + else: + params["images"] = None + + # --- New: If file_paths contain local video files, extract a few frames --- + try: + file_paths_value = params.get("file_paths") + candidate_paths: List[str] = [] + if isinstance(file_paths_value, str) and file_paths_value.strip(): + candidate_paths = [file_paths_value.strip()] + elif isinstance(file_paths_value, list): + candidate_paths = [p for p in file_paths_value if isinstance(p, str) and p.strip()] + + extracted_from_files: List[str] = [] + for p in candidate_paths: + if _is_video_file(p) and os.path.isfile(p): + extracted_from_files.extend(_extract_video_file_frames_as_b64(p)) + + if extracted_from_files: + if params.get("images") is None: + params["images"] = [] + if not isinstance(params["images"], list): + params["images"] = [params["images"]] # normalize + # cap total images to 8 to avoid huge payloads + remaining = max(0, 8 - len(params["images"])) + params["images"].extend(extracted_from_files[:remaining]) + logger.info("Added %d frame(s) extracted from video file paths to images payload.", min(len(extracted_from_files), remaining)) + except Exception as e: + logger.warning("Failed to process video file paths for frame extraction: %s", e, exc_info=True) + + log_params = _sanitize_params_for_log(params) + logger.info(f"[Non-Streaming] Making LLM request with params: {log_params}") + + try: + # --- CALL NON-STREAMING VERSION --- + response_data = run_async( + send_request( # Original non-streaming call + llm_provider=params["llm_provider"], + 
base_ip=params.get("base_ip", "localhost"), + port=params.get("port", "11434"), + images=params.get("images"), + llm_model=params["llm_model"], + system_message=params["system_message"], + user_message=params["user_message"], + messages=params["messages"], + seed=params.get("seed"), + temperature=params["temperature"], + max_tokens=params["max_tokens"], + random=params.get("random", False), + top_k=params["top_k"], + top_p=params["top_p"], + repeat_penalty=params["repeat_penalty"], + stop=params.get("stop"), + keep_alive=params.get("keep_alive", True), + llm_api_key=params.get("llm_api_key"), + reasoning_format=params.get("reasoning_format") # Pass reasoning_format + ) + ) + # --- END NON-STREAMING CALL --- + except Exception as e: + logger.error(f"Error in non-streaming send_request call: {e}", exc_info=True) + response_data = {"choices": [{"message": {"content": f"Error calling send_request: {str(e)}"}}]} + + if response_data is None: + content = "Error: Received None response" + elif isinstance(response_data, dict): + if "choices" in response_data and response_data["choices"]: + message = response_data["choices"][0].get("message", {}) + content = message.get("content", "") + if content is None: + content = "Error: Null content in response" + elif "response" in response_data: + content = response_data["response"] + else: + content = f"Error: Unexpected format: {str(response_data)}" + elif isinstance(response_data, str): + content = response_data + else: + content = f"Error: Unexpected type: {type(response_data)}" + + # Apply thinking tag removal if requested + if hide_thinking and content: + content = _remove_thinking_tags(content) + + if context is not None and isinstance(context, dict): + context_out = context.copy() + context_out["llm_response"] = content + context_out["llm_raw_response"] = response_data + else: + context_out = {"llm_response": content, "llm_raw_response": response_data, "passthrough_data": context} + + payload = ContextPayload(content, 
context_out) + return {"ui": {"string": [content]}, "result": (payload, str(payload))} + + except Exception as e: + error_message = f"Error generating text: {str(e)}" + logger.error(error_message, exc_info=True) + error_output = {"error": error_message, "original_input": context} + payload = ContextPayload(error_message, error_output) + return {"ui": {"string": [error_message]}, "result": (payload, str(payload))} + +# --- Helper to avoid dumping large base64 blobs to INFO log ----------------- +def _sanitize_params_for_log(d: dict) -> dict: + """Return a shallow copy with huge fields replaced by short placeholders.""" + out = {} + for k, v in d.items(): + # Hide large binary blobs (images, video frames) + if k in {"images", "video_frames", "base64_images"} and v: + # Replace with concise placeholder + if isinstance(v, (list, tuple)): + out[k] = f"[{len(v)} item(s) base64 omitted]" + elif isinstance(v, str): + out[k] = "[base64 string omitted]" + else: + out[k] = "[data omitted]" + continue + + # Mask any values that look like API keys so we never log them in full + key_lc = k.lower() + if "api_key" in key_lc or key_lc.endswith("_key") or key_lc.endswith("key"): + if isinstance(v, str) and v: + # Keep first 5 chars for debugging, mask the rest + masked = v[:5] + "…" if len(v) > 5 else "***" + out[k] = masked + else: + out[k] = "[key hidden]" + continue + + # Default passthrough + else: + out[k] = v + return out + +# --- NEW STREAMING NODE --- +class LLMToolkitTextGeneratorStream: + DEFAULT_PROVIDER = "openai" + + DEFAULT_MODEL: str = "gpt-4o-mini" + + MODEL_LIST: List[str] = [DEFAULT_MODEL] + + @classmethod + def INPUT_TYPES(cls): + return { + "required": { + "prompt": ("STRING", {"multiline": False, "default": "Write a detailed description of a futuristic city."}), + "hide_thinking": ("BOOLEAN", {"default": True, "tooltip": "Hide model thinking process (content between tags)"}) + }, + "optional": { + "context": ("*", {}), + "temperature": ("FLOAT", {"default": 0.7, 
"min": 0.0, "max": 2.0, "step": 0.05}), + "max_tokens": ("INT", {"default": 1024, "min": 1, "max": 65536}), + "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.05}), + "top_k": ("INT", {"default": 40, "min": 0, "max": 500}), + "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647, "tooltip": "-1 = random seed"}) + }, + "hidden": { # <-- Add hidden inputs + "unique_id": "UNIQUE_ID", + "llm_model": ("STRING", {"default": cls.DEFAULT_MODEL}) + }, + } + + RETURN_TYPES = ("*", "STRING") + RETURN_NAMES = ("context", "text") + FUNCTION = "generate_stream" # <-- Use new function name + CATEGORY = "🔗llm_toolkit/generators" + OUTPUT_NODE = True # Keep the JS widget logic + + def generate_stream(self, prompt, hide_thinking, unique_id, llm_model=None, context=None, **kwargs): + """ + Generates text using the specified provider and streams the response back + to the UI via websocket messages. (synchronous wrapper) + """ + # Wrap the previous async implementation inside an inner coroutine + async def _async_generate(): + # Previous async body START + if PromptServer is None: + logger.error("PromptServer not available. Cannot stream.") + error_msg = "Streaming requires PromptServer, which is not available." + error_output = {"error": error_msg, "original_input": context} + payload = ContextPayload(error_msg, error_output) + return {"ui": {"string": [error_msg]}, "result": (payload, str(payload))} + + server = PromptServer.instance + full_response_text = "" + thinking_buffer = "" + inside_thinking = False + + try: + # --- Corrected Parameter Processing Logic --- + # 1. 
Start with node defaults + params = { + "llm_provider": self.DEFAULT_PROVIDER, + "llm_model": llm_model or self.DEFAULT_MODEL, + "system_message": "You are a helpful, creative, and concise assistant.", + "user_message": prompt, + "base_ip": "localhost", "port": "11434", + "temperature": kwargs.get('temperature', 0.7), "max_tokens": kwargs.get('max_tokens', 1024), "top_p": kwargs.get('top_p', 0.9), "top_k": kwargs.get('top_k', 40), "seed": kwargs.get('seed', -1), + "repeat_penalty": 1.1, "stop": None, "keep_alive": "5m", + "messages": [], "llm_api_key": None, + } + + # 2. Apply general settings from the incoming context + if context and isinstance(context, dict): + for key, value in context.items(): + if key in params and value is not None: + params[key] = value + + # 3. Apply specific provider_config settings, which take precedence + provider_config = None + if context and isinstance(context, dict): + if "provider_config" in context and isinstance(context["provider_config"], dict): + provider_config = context["provider_config"] + elif "provider_name" in context: # Handle flat context from old nodes + provider_config = context + + if provider_config: + # First, update all matching keys from provider_config that also exist in params + params.update({k: v for k, v in provider_config.items() if k in params and v is not None}) + + # Second, handle specific key name differences. This ensures provider_name is mapped correctly. 
+ if "provider_name" in provider_config and provider_config["provider_name"]: + params["llm_provider"] = provider_config["provider_name"] + elif context and isinstance(context, dict) and context.get("provider_name"): + # Fallback to root-level key if nested provider_config didn't have it + params["llm_provider"] = context["provider_name"] + + if "api_key" in provider_config and provider_config["api_key"]: + params["llm_api_key"] = provider_config["api_key"] + elif context and isinstance(context, dict) and context.get("api_key"): + params["llm_api_key"] = context["api_key"] + + if "llm_model" in provider_config and provider_config["llm_model"]: + params["llm_model"] = provider_config["llm_model"] + elif context and isinstance(context, dict) and context.get("llm_model"): + params["llm_model"] = context["llm_model"] + + # 4. The user_message from context or provider_config can be used, + # but the node's prompt input has the final say. + if context and isinstance(context, dict) and "user_prompt" in context and context["user_prompt"]: + params["user_message"] = context.get("user_prompt") + if provider_config and "user_prompt" in provider_config and provider_config["user_prompt"]: + params["user_message"] = provider_config.get("user_prompt") + if prompt: # Node input is final override, if provided + params["user_message"] = prompt + + # --- End Corrected Parameter Processing --- + + # Finalize model name fallback + provider = "unknown" + if params.get("llm_provider"): + provider = str(params["llm_provider"]).lower() + if not params.get("llm_model"): + PROVIDER_DEFAULTS = {"openai": "gpt-4o-mini", "anthropic": "claude-3-opus-20240229"} + fallback = PROVIDER_DEFAULTS.get(provider) + params["llm_model"] = fallback or self.DEFAULT_MODEL + + # Groq reasoning format + if provider == "groq": + params["reasoning_format"] = "hidden" if hide_thinking else "raw" + + # Auto-fetch API key for OpenAI if missing/placeholder + if provider == "openai" and (not params.get("llm_api_key") 
or params["llm_api_key"] in {"", "1234", None}): + try: + params["llm_api_key"] = get_api_key("OPENAI_API_KEY", "openai") + logger.info("generate_stream: Retrieved OpenAI API key via get_api_key helper.") + except ValueError as _e: + logger.warning(f"generate_stream: get_api_key failed – {_e}") + + # --- Prompt-Manager support --------------------------------------------------- + # If upstream PromptManager attached a prompt_config dict, use the data + prompt_cfg = None + if context is not None and isinstance(context, dict): + prompt_cfg = context.get("prompt_config") + + if prompt_cfg and isinstance(prompt_cfg, dict): + # Override user prompt text if provided + if prompt_cfg.get("text"): + params["user_message"] = prompt_cfg["text"] + + # Collect images and video frames (already in base64 from PromptManager) + imgs = [] + if prompt_cfg.get("image_base64"): + img_val = prompt_cfg["image_base64"] + if isinstance(img_val, str): + imgs.extend([img_val]) + elif isinstance(img_val, list): + imgs.extend(img_val) + + # Add extracted video frames + if prompt_cfg.get("video_frames_base64"): + vid_frames = prompt_cfg["video_frames_base64"] + if isinstance(vid_frames, list): + imgs.extend(vid_frames) + else: + imgs.append(vid_frames) + + params["images"] = imgs if imgs else None + + # Pass through file paths and URLs for APIs that support them + if prompt_cfg.get("file_paths"): + params["file_paths"] = prompt_cfg["file_paths"] + logger.debug(f"PromptManager: Passing file_paths to API (some APIs process videos/PDFs directly).") + + if prompt_cfg.get("urls"): + params["urls"] = prompt_cfg["urls"] + logger.debug(f"PromptManager: Passing URLs to API.") + + # Audio path remains separate + if prompt_cfg.get("audio_path"): + params["audio_path"] = prompt_cfg["audio_path"] + else: + params["images"] = None + + # --- New: If file_paths contain local video files, extract a few frames --- + try: + file_paths_value = params.get("file_paths") + candidate_paths: List[str] = [] + if 
isinstance(file_paths_value, str) and file_paths_value.strip(): + candidate_paths = [file_paths_value.strip()] + elif isinstance(file_paths_value, list): + candidate_paths = [p for p in file_paths_value if isinstance(p, str) and p.strip()] + + extracted_from_files: List[str] = [] + for p in candidate_paths: + if _is_video_file(p) and os.path.isfile(p): + extracted_from_files.extend(_extract_video_file_frames_as_b64(p)) + + if extracted_from_files: + if params.get("images") is None: + params["images"] = [] + if not isinstance(params["images"], list): + params["images"] = [params["images"]] # normalize + # cap total images to 8 to avoid huge payloads + remaining = max(0, 8 - len(params["images"])) + params["images"].extend(extracted_from_files[:remaining]) + logger.info( + "[Streaming] Added %d frame(s) extracted from video file paths to images payload.", + min(len(extracted_from_files), remaining), + ) + except Exception as e: + logger.warning( + "[Streaming] Failed to process video file paths for frame extraction: %s", + e, + exc_info=True, + ) + + log_params = _sanitize_params_for_log(params) + logger.info( + f"[Streaming] Initiating LLM stream with params: {log_params} for node {unique_id}" + ) + + # --- Send START message --- + server.send_sync("llmtoolkit.stream.start", {"node": unique_id}, sid=server.client_id) + + # --- Initiate and process the stream --- + stream_generator = send_request_stream( + llm_provider=params["llm_provider"], + base_ip=params.get("base_ip", "localhost"), + port=params.get("port", "11434"), + llm_model=params["llm_model"], + system_message=params["system_message"], + user_message=params["user_message"], + messages=params["messages"], + seed=params.get("seed"), + temperature=params["temperature"], + max_tokens=params["max_tokens"], + random=params.get("random", False), + top_k=params["top_k"], + top_p=params["top_p"], + repeat_penalty=params["repeat_penalty"], + stop=params.get("stop"), + keep_alive=params.get("keep_alive", True), + 
llm_api_key=params.get("llm_api_key"), + base64_images=params.get("images"), + ) + + async for chunk in stream_generator: + if chunk: + full_response_text += chunk + + # Handle thinking tag filtering for streaming + if hide_thinking: + # Track thinking state and buffer content + chunk_to_send = "" + i = 0 + while i < len(chunk): + + if not inside_thinking: + # Check for start of thinking tag + if chunk[i:i+7] == '': + inside_thinking = True + thinking_buffer = '' + i += 7 + continue + elif chunk[i:i+9] == '◁think▷': # <-- Add Kimi-VL start tag + inside_thinking = True + thinking_buffer = '◁think▷' + i += 9 + continue + else: + chunk_to_send += chunk[i] + else: + # Inside thinking block, buffer until we find closing tag + thinking_buffer += chunk[i] + if thinking_buffer.endswith(''): + inside_thinking = False + thinking_buffer = "" + i += 1 + continue + elif thinking_buffer.endswith('◁/think▷'): # <-- Add Kimi-VL end tag + inside_thinking = False + thinking_buffer = "" + i += 1 + continue + i += 1 + + # Only send non-thinking content + if chunk_to_send: + server.send_sync( + "llmtoolkit.stream.chunk", + {"node": unique_id, "text": chunk_to_send}, + sid=server.client_id, + ) + else: + # Send all content if not hiding thinking + server.send_sync( + "llmtoolkit.stream.chunk", + {"node": unique_id, "text": chunk}, + sid=server.client_id, + ) + await asyncio.sleep(0.001) + + # Apply thinking tag removal to final text if requested + final_text = full_response_text + if hide_thinking: + final_text = _remove_thinking_tags(full_response_text) + + logger.info( + f"[Streaming] Finished for node {unique_id}. 
Total length: {len(final_text)}" + ) + server.send_sync( + "llmtoolkit.stream.end", + {"node": unique_id, "final_text": final_text}, + sid=server.client_id, + ) + + # --- Prepare final context output --- + if context is not None and isinstance(context, dict): + context_out = context.copy() + context_out["llm_response"] = final_text + context_out["llm_raw_response"] = { + "status": "Streamed successfully", + "final_length": len(final_text), + } + else: + context_out = { + "llm_response": final_text, + "llm_raw_response": { + "status": "Streamed successfully", + "final_length": len(final_text), + }, + "passthrough_data": context, + } + + payload = ContextPayload(final_text, context_out) + return {"ui": {"string": [final_text]}, "result": (payload, str(payload))} + + except Exception as e: + error_message = f"Error during streaming generation: {str(e)}" + logger.error(error_message, exc_info=True) + if server and unique_id: + server.send_sync( + "llmtoolkit.stream.error", + {"node": unique_id, "error": error_message}, + sid=server.client_id, + ) + error_output = { + "error": error_message, + "partial_response": full_response_text, + "original_input": context, + } + payload = ContextPayload(error_message, error_output) + return { + "ui": {"string": [f"Error: {error_message}\nPartial: {full_response_text}"]}, + "result": (payload, str(payload)), + } + # Previous async body END + + # Execute the inner coroutine and return its result synchronously + return run_async(_async_generate()) + + +# Node Mappings for ComfyUI +NODE_CLASS_MAPPINGS = { + "LLMToolkitTextGenerator": LLMToolkitTextGenerator, # Keep original + "LLMToolkitTextGeneratorStream": LLMToolkitTextGeneratorStream # Add streaming version +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "LLMToolkitTextGenerator": "Generate Text (🔗LLMToolkit)", + "LLMToolkitTextGeneratorStream": "Generate Text Stream (🔗LLMToolkit)" # New display name +} \ No newline at end of file