# generate_text.py import os import sys import json import base64 import asyncio import logging from typing import Dict, Any, List, Optional, Union, Tuple, AsyncGenerator # Set up logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) # Add parent directory to path to ensure imports work correctly current_dir = os.path.dirname(os.path.abspath(__file__)) parent_dir = os.path.dirname(current_dir) if parent_dir not in sys.path: sys.path.insert(0, parent_dir) # Try to import ComfyUI-specific modules try: import folder_paths from server import PromptServer # <-- Import PromptServer except ImportError: logger.warning("Could not import folder_paths or server. Make sure ComfyUI environment is set up.") folder_paths = None PromptServer = None # Check for required dependencies missing_deps = [] try: import aiohttp except ImportError: missing_deps.append("aiohttp") if missing_deps: logger.warning(f"Missing dependencies: {', '.join(missing_deps)}. Some functionality may not work.") logger.warning("Please install missing dependencies: pip install " + " ".join(missing_deps)) # Import utility functions (assuming they exist) from send_request import send_request, run_async, is_gpt5_model # Keep non-streaming version if needed from api.openai_api import send_openai_responses_stream from api.gemini_api import send_gemini_request_stream from llmtoolkit_utils import query_local_ollama_models, ensure_ollama_server, ensure_ollama_model, get_api_key # Local transformers streaming (optional) - removed unused import # The send_transformers_request_stream function was imported but never used send_transformers_request_stream = None # type: ignore # Payload helper to embed context into a string subclass from context_payload import ContextPayload # ----------------------------------------------------------------------------- # Helpers: Lazy OpenCV import and video -> base64 frame extraction # ----------------------------------------------------------------------------- _cv2_ref = None def _get_cv2(): global _cv2_ref if _cv2_ref is None: try: import cv2 as _cv2 # type: ignore _cv2_ref = _cv2 except Exception: _cv2_ref = None logger.warning("cv2 not available. Video file frame extraction disabled.") return _cv2_ref def _is_video_file(path: str) -> bool: try: ext = os.path.splitext(path)[1].lower() except Exception: return False return ext in {".mp4", ".mov", ".mkv", ".avi", ".webm", ".m4v"} def _extract_video_file_frames_as_b64( video_path: str, max_frames: int = 5, stride: int = 16, ) -> List[str]: """Extract up to `max_frames` JPEG base64 frames from a local video file. - Uses a simple stride to subsample frames. - Falls back to the first N frames if the video is very short. - Returns an empty list on any error or if cv2 is unavailable. """ cv2 = _get_cv2() if cv2 is None: return [] try: cap = cv2.VideoCapture(video_path) if not cap.isOpened(): logger.warning("Could not open video file: %s", video_path) return [] frames: List[str] = [] frame_index = 0 picked = 0 # Try stride sampling first while picked < max_frames: cap.set(cv2.CAP_PROP_POS_FRAMES, frame_index) ok, frame = cap.read() if not ok: break success, buf = cv2.imencode(".jpg", frame) if success: frames.append(base64.b64encode(buf.tobytes()).decode("ascii")) picked += 1 frame_index += stride # If we got nothing with stride, try the first sequential frames if not frames: cap.set(cv2.CAP_PROP_POS_FRAMES, 0) count = 0 while count < max_frames: ok, frame = cap.read() if not ok: break success, buf = cv2.imencode(".jpg", frame) if success: frames.append(base64.b64encode(buf.tobytes()).decode("ascii")) count += 1 cap.release() if frames: logger.debug("Extracted %d frame(s) from video file %s", len(frames), video_path) return frames except Exception as e: logger.warning("Error extracting frames from %s: %s", video_path, e, exc_info=True) return [] # --- NEW STREAMING REQUEST FUNCTION (Example for Ollama) --- # IMPORTANT: This needs to be adapted based on the actual API structure of the provider! async def send_request_stream( llm_provider: str, base_ip: str, port: str, llm_model: str, system_message: str, user_message: str, messages: List[Dict[str, str]], seed: Optional[int] = None, temperature: float = 0.7, max_tokens: int = 1024, random: bool = False, top_k: int = 40, top_p: float = 0.9, repeat_penalty: float = 1.1, stop: Optional[List[str]] = None, keep_alive: Union[bool, str] = True, llm_api_key: Optional[str] = None, timeout: int = 120, # Add a timeout for the connection base64_images: Optional[List[str]] = None, ) -> AsyncGenerator[str, None]: """ Sends a streaming request to an LLM provider (Example for Ollama). Yields text chunks as they are received. """ provider_lower = llm_provider.lower() if provider_lower in ["openai", "openrouter"]: # --- OpenAI & OpenRouter Specific Streaming Logic --- if not llm_api_key: logger.error(f"{llm_provider} streaming requested but no API key supplied.") yield f"[{llm_provider} Error: API key missing]" return # Check if this is a GPT-5 model and use Responses API (OpenAI only) if provider_lower == "openai" and is_gpt5_model(llm_model): logger.info(f"Detected GPT-5 model: {llm_model}, using Responses API for streaming") try: async for chunk in send_openai_responses_stream( api_url="https://api.openai.com/v1/responses", base64_images=base64_images, model=llm_model, system_message=system_message, user_message=user_message, messages=messages, api_key=llm_api_key, temperature=temperature, max_tokens=max_tokens, top_p=top_p, ): yield chunk return except Exception as e: logger.warning(f"GPT-5 Responses stream failed: {e}, falling back to Chat Completions") # Fall through to regular OpenAI streaming below api_url = "https://openrouter.ai/api/v1/chat/completions" if provider_lower == "openrouter" else "https://api.openai.com/v1/chat/completions" headers = { "Authorization": f"Bearer {llm_api_key}", "Content-Type": "application/json", } # Build message list if not provided if not messages: messages = [] if system_message: messages.append({"role": "system", "content": system_message}) # Handle multimodal user content (text + images) if base64_images: content_blocks = [] if user_message: content_blocks.append({"type": "text", "text": user_message}) for img_b64 in base64_images: content_blocks.append({ "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}, }) messages.append({"role": "user", "content": content_blocks}) else: if user_message: messages.append({"role": "user", "content": user_message}) payload = { "model": llm_model, "messages": messages, "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, "stream": True, } # Remove None values payload = {k: v for k, v in payload.items() if v is not None} logger.info(f"Streaming request to {llm_provider}: model={llm_model}") session = None try: # Create session with custom connector for Windows connector = aiohttp.TCPConnector(force_close=True) if sys.platform == 'win32' else None session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=timeout), connector=connector ) async with session.post(api_url, headers=headers, json=payload) as response: response.raise_for_status() async for raw_line in response.content: if not raw_line: continue line = raw_line.decode("utf-8").strip() if not line: continue # OpenAI streams multiple lines that may begin with 'data:'; join those if needed. if line.startswith("data: "): data_str = line[len("data: ") :].strip() else: data_str = line if data_str == "[DONE]": break try: data_json = json.loads(data_str) choices = data_json.get("choices", []) if choices: delta = choices[0].get("delta", {}) content_piece = delta.get("content") if content_piece: yield content_piece except json.JSONDecodeError: logger.warning(f"Could not decode JSON line from OpenAI stream: {data_str}") except Exception as e: logger.error(f"Error during OpenAI streaming: {e}", exc_info=True) yield f"[OpenAI streaming error: {e}]" finally: # Ensure session is properly closed if session and not session.closed: await session.close() # Small delay to allow cleanup on Windows if sys.platform == 'win32': await asyncio.sleep(0.1) return if provider_lower == "gemini": if not llm_api_key: logger.error("Gemini streaming requested but no API key supplied.") yield "[Gemini Error: API key missing]" return async for chunk in send_gemini_request_stream( api_key=llm_api_key, model=llm_model, system_message=system_message, user_message=user_message, messages=messages, base64_images=base64_images, temperature=temperature, max_tokens=max_tokens, top_p=top_p, top_k=top_k, ): yield chunk return if provider_lower == "ollama": # --- Ollama Specific Streaming Logic --- if not ensure_ollama_server(base_ip, port): logger.error("Ollama daemon unavailable and could not be started – aborting stream.") yield "[Error: Ollama daemon unavailable]" return # Ensure requested model is present locally (will pull if missing) ensure_ollama_model(llm_model, base_ip, port) # Ollama expects plain base64 strings without the 'data:image/...' prefix if base64_images: base64_images = [ img.split("base64,")[1] if isinstance(img, str) and "base64," in img else img for img in base64_images ] # Decide which Ollama endpoint to use: # • /api/generate – fast text-only streaming (no image support) # • /api/chat – full chat/completions with image support # We switch to /api/chat automatically if the caller supplied base64-encoded images # so that vision models (e.g. qwen-vl, llava) receive the image bytes. use_chat_endpoint = bool(base64_images) # True if we have images to send url = ( f"http://{base_ip}:{port}/api/chat" if use_chat_endpoint else f"http://{base_ip}:{port}/api/generate" ) headers = {"Content-Type": "application/json"} if llm_api_key: # Ollama doesn't typically use API keys this way, but include for consistency headers["Authorization"] = f"Bearer {llm_api_key}" # ------------------------------------------------------------------ # Build request payloads # ------------------------------------------------------------------ if use_chat_endpoint: # -------------------------------------------------------------- # /api/chat (supports multimodal, messages array) # -------------------------------------------------------------- if not messages: messages = [] if system_message: messages.append({"role": "system", "content": system_message}) # Always append the user message at the end so that vision models # get the most recent prompt + images in a single message. user_msg: dict[str, Any] = {"role": "user", "content": user_message or ""} if base64_images: user_msg["images"] = base64_images # Ollama expects a list key called "images" messages.append(user_msg) payload = { "model": llm_model, "messages": messages, "stream": True, "options": { "seed": seed, "temperature": temperature, "num_predict": max_tokens, "top_k": top_k, "top_p": top_p, "repeat_penalty": repeat_penalty, "stop": stop, }, } else: # -------------------------------------------------------------- # /api/generate (text-only) # -------------------------------------------------------------- # Construct messages list if not provided directly (for context) if not messages: messages = [] if system_message: messages.append({"role": "system", "content": system_message}) if user_message: messages.append({"role": "user", "content": user_message}) payload = { "model": llm_model, "prompt": user_message, "system": system_message if system_message else None, "stream": True, "options": { "seed": seed, "temperature": temperature, "num_predict": max_tokens, "top_k": top_k, "top_p": top_p, "repeat_penalty": repeat_penalty, "stop": stop, }, } # Clean up None values Ollama might not like if not system_message: del payload["system"] # Remove None values from options dict if "options" in payload: payload["options"] = {k: v for k, v in payload["options"].items() if v is not None} logger.info(f"Streaming request to Ollama: {url} with payload: {{'model': '{llm_model}', 'stream': True, 'use_chat': {use_chat_endpoint}}}") try: # Use a single session if possible, manage timeouts async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: async with session.post(url, headers=headers, json=payload) as response: response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx) # Process the streaming response line by line. The JSON schema differs slightly # between /api/generate and /api/chat, so we branch inside the loop. async for line in response.content: if not line: continue try: decoded_line = line.decode("utf-8").strip() if not decoded_line: continue data = json.loads(decoded_line) if use_chat_endpoint: # Chat endpoint returns {"message": {"content": "..."}, "done": bool} msg = data.get("message", {}) chunk = msg.get("content", "") is_done = data.get("done", False) else: # Generate endpoint returns {"response": "...", "done": bool} chunk = data.get("response", "") is_done = data.get("done", False) if chunk: yield chunk if is_done: logger.info("Ollama stream finished.") break except json.JSONDecodeError: logger.warning( f"Could not decode JSON line from Ollama stream: {line.decode('utf-8', errors='ignore')}" ) except Exception as e: logger.error(f"Error processing Ollama stream line: {e}", exc_info=True) yield f"[Error processing stream: {e}]" break # Stop streaming on error except aiohttp.ClientConnectorError as e: logger.error(f"Connection error to {url}: {e}") yield f"[Connection Error: {e}]" except aiohttp.ClientResponseError as e: logger.error(f"HTTP error {e.status} from {url}: {e.message}") # Attempt to read error details from response body try: error_body = await e.response.text() if hasattr(e, 'response') else 'No details' logger.error(f"Error Body: {error_body}") yield f"[HTTP Error {e.status}: {e.message} - {error_body[:100]}]" except: yield f"[HTTP Error {e.status}: {e.message}]" except asyncio.TimeoutError: logger.error(f"Request timed out after {timeout} seconds to {url}") yield f"[Timeout Error]" except ConnectionResetError as e: logger.warning(f"Connection reset by peer: {e}") yield f"[Connection Reset: The server closed the connection]" except OSError as e: if e.errno == 10054: # Windows specific: connection forcibly closed logger.warning("Connection forcibly closed by remote host") yield f"[Connection closed by server]" else: logger.error(f"OS Error during streaming: {e}") yield f"[Network Error: {e}]" except Exception as e: logger.error(f"An unexpected error occurred during streaming request: {e}", exc_info=True) yield f"[Unexpected Error: {e}]" if provider_lower == "groq": # --- Groq Specific Streaming Logic --- if not llm_api_key: logger.error("Groq streaming requested but no API key supplied.") # Fallback could be added here if desired, for now just yield error yield "[Groq Error: API key missing]" return groq_url = "https://api.groq.com/openai/v1/chat/completions" headers = { "Authorization": f"Bearer {llm_api_key}", "Content-Type": "application/json", } # Handle multimodal user content (text + images) is_vision_model = "scout" in llm_model or "maverick" in llm_model images_to_send = base64_images if images_to_send and not is_vision_model: logger.warning("Groq stream: Model '%s' may not support images. Sending without.", llm_model) images_to_send = None elif images_to_send and len(images_to_send) > 5: logger.warning("Groq stream: Taking first 5 of %s images for vision model.", len(images_to_send)) images_to_send = images_to_send[:5] # Build message list if not provided if not messages: messages = [] if system_message: messages.append({"role": "system", "content": system_message}) if images_to_send: content_blocks = [{"type": "text", "text": user_message}] for img_b64 in images_to_send: content_blocks.append({ "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"}, }) messages.append({"role": "user", "content": content_blocks}) else: if user_message: messages.append({"role": "user", "content": user_message}) payload = { "model": llm_model, "messages": messages, "temperature": temperature, "max_completion_tokens": max_tokens, # Use correct Groq param "top_p": top_p, "stream": True, } # Remove None values payload = {k: v for k, v in payload.items() if v is not None} logger.info(f"Streaming request to Groq: model={llm_model}") session = None try: # Create session with custom connector for Windows connector = aiohttp.TCPConnector(force_close=True) if sys.platform == 'win32' else None session = aiohttp.ClientSession( timeout=aiohttp.ClientTimeout(total=timeout), connector=connector ) async with session.post(groq_url, headers=headers, json=payload) as response: response.raise_for_status() async for raw_line in response.content: line = raw_line.decode("utf-8").strip() if not line: continue if line.startswith("data: "): data_str = line[len("data: ") :].strip() else: data_str = line if data_str == "[DONE]": break try: data_json = json.loads(data_str) choices = data_json.get("choices", []) if choices: delta = choices[0].get("delta", {}) content_piece = delta.get("content") if content_piece: yield content_piece except json.JSONDecodeError: logger.warning(f"Could not decode JSON line from Groq stream: {data_str}") except Exception as e: logger.error(f"Error during Groq streaming: {e}", exc_info=True) yield f"[Groq streaming error: {e}]" finally: # Ensure session is properly closed if session and not session.closed: await session.close() # Small delay to allow cleanup on Windows if sys.platform == 'win32': await asyncio.sleep(0.1) return # Existing fallback logic for other providers if provider_lower not in ["ollama", "openai", "openrouter", "transformers", "hf", "local", "groq", "gemini"]: logger.warning(f"Streaming not implemented for provider '{llm_provider}'. Falling back to non-streaming.") try: full_response_data = await send_request( llm_provider=llm_provider, base_ip=base_ip, port=port, images=base64_images, llm_model=llm_model, system_message=system_message, user_message=user_message, messages=messages, seed=seed, temperature=temperature, max_tokens=max_tokens, random=random, top_k=top_k, top_p=top_p, repeat_penalty=repeat_penalty, stop=stop, keep_alive=keep_alive, llm_api_key=llm_api_key ) if isinstance(full_response_data, dict): if "choices" in full_response_data and full_response_data["choices"]: message = full_response_data["choices"][0].get("message", {}) content = message.get("content", "") if content: yield content elif "response" in full_response_data: if full_response_data["response"]: yield full_response_data["response"] elif "candidates" in full_response_data and full_response_data.get("candidates"): try: content = full_response_data["candidates"][0]["content"]["parts"][0]["text"] if content: yield content except (KeyError, IndexError, TypeError): logger.warning(f"Could not parse Gemini response format: {full_response_data}") else: logger.error(f"Unexpected non-streaming response format: {full_response_data}") elif isinstance(full_response_data, str): if full_response_data: yield full_response_data else: logger.error(f"Unexpected non-streaming response type: {type(full_response_data)}") except Exception as e: logger.error(f"Error in fallback non-streaming request for {llm_provider}: {e}", exc_info=True) yield f"[Error: {e}]" return # Stop generation after yielding the fallback response def _remove_thinking_tags(text: str) -> str: """Remove ... blocks from text, including the tags themselves.""" import re # Pattern to match ... or ◁think▷...◁/think▷ pattern = r'.*?|◁think▷.*?◁/think▷' cleaned = re.sub(pattern, '', text, flags=re.DOTALL) # Clean up any extra whitespace/newlines left behind cleaned = re.sub(r'\n\s*\n\s*\n', '\n\n', cleaned) # Replace multiple newlines with double return cleaned.strip() # --- Original Node (for reference or non-streaming use) --- class LLMToolkitTextGenerator: DEFAULT_PROVIDER = "openai" DEFAULT_MODEL: str = "gpt-4o-mini" MODEL_LIST: List[str] = [DEFAULT_MODEL] @classmethod def INPUT_TYPES(cls): return { "required": { "prompt": ("STRING", {"multiline": False, "default": "Write a short story about a robot learning to paint."}), "hide_thinking": ("BOOLEAN", {"default": True, "tooltip": "Hide model thinking process (content between tags)"}) }, "optional": { "context": ("*", {}), "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.05}), "max_tokens": ("INT", {"default": 1024, "min": 1, "max": 65536}), "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.05}), "top_k": ("INT", {"default": 40, "min": 0, "max": 500}), "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647, "tooltip": "-1 = random seed"}) }, "hidden": { "llm_model": ("STRING", {"default": cls.DEFAULT_MODEL}) } } RETURN_TYPES = ("*", "STRING") RETURN_NAMES = ("context", "text") FUNCTION = "generate" CATEGORY = "🔗llm_toolkit/generators" OUTPUT_NODE = True # Keeps the text widget for non-streaming version def generate(self, prompt, hide_thinking, context=None, temperature=0.7, max_tokens=1024, top_p=0.9, top_k=40, seed=-1, llm_model=None): # ... (original generate logic using run_async(send_request(...))) ... # This function remains mostly the same as the user provided, # calling the original non-streaming send_request. # We'll copy the parameter processing logic from the streaming version # for consistency, but it will call the non-streaming send_request. try: # Base parameter defaults params = { "llm_provider": self.DEFAULT_PROVIDER, "llm_model": llm_model or self.DEFAULT_MODEL, "system_message": "You are a helpful, creative, and concise assistant.", "user_message": prompt, "base_ip": "localhost", "port": "11434", "temperature": temperature, "max_tokens": max_tokens, "top_p": top_p, "top_k": top_k, "seed": seed if seed != -1 else None, "repeat_penalty": 1.1, "stop": None, "keep_alive": True, "messages": [], } provider_config = None if context is not None: if isinstance(context, dict) and "provider_name" in context: provider_config = context elif isinstance(context, dict) and "provider_config" in context: provider_config = context["provider_config"] if provider_config and isinstance(provider_config, dict): if "provider_name" in provider_config: params["llm_provider"] = provider_config["provider_name"] if "api_key" in provider_config: params["llm_api_key"] = provider_config["api_key"] if "base_ip" in provider_config: params["base_ip"] = provider_config["base_ip"] if "port" in provider_config: params["port"] = provider_config["port"] for key in provider_config: if key not in ["provider_name", "llm_model", "api_key", "base_ip", "port", "user_prompt"]: params[key] = provider_config[key] if "user_prompt" in provider_config: params["user_message"] = provider_config["user_prompt"] provider_model = provider_config.get("llm_model", "") if provider_model: params["llm_model"] = provider_model else: params["llm_model"] = "" # Will be handled below if params.get("llm_provider"): provider = str(params["llm_provider"]).lower() if not params.get("llm_model"): PROVIDER_DEFAULTS = {"openai": "gpt-4o-mini", "anthropic": "claude-3-opus-20240229"} fallback = PROVIDER_DEFAULTS.get(provider) if fallback: params["llm_model"] = fallback # --- Groq reasoning-format handling --- if provider == "groq": # Hide thinking via API when requested to avoid post-processing. params["reasoning_format"] = "hidden" if hide_thinking else "raw" # Auto-fetch API key for OpenAI if missing/placeholder if provider == "openai" and (not params.get("llm_api_key") or params["llm_api_key"] in {"", "1234", None}): try: params["llm_api_key"] = get_api_key("OPENAI_API_KEY", "openai") logger.info("generate: Retrieved OpenAI API key via get_api_key helper.") except ValueError as _e: logger.warning(f"generate: get_api_key failed – {_e}") # --- Prompt-Manager support --------------------------------------------------- # If upstream PromptManager attached a prompt_config dict, use the data prompt_cfg = None if context is not None and isinstance(context, dict): prompt_cfg = context.get("prompt_config") if prompt_cfg and isinstance(prompt_cfg, dict): # Override user prompt text if provided if prompt_cfg.get("text"): params["user_message"] = prompt_cfg["text"] # Collect images and video frames (already in base64 from PromptManager) imgs = [] if prompt_cfg.get("image_base64"): img_val = prompt_cfg["image_base64"] if isinstance(img_val, str): imgs.extend([img_val]) elif isinstance(img_val, list): imgs.extend(img_val) # Add extracted video frames if prompt_cfg.get("video_frames_base64"): vid_frames = prompt_cfg["video_frames_base64"] if isinstance(vid_frames, list): imgs.extend(vid_frames) else: imgs.append(vid_frames) params["images"] = imgs if imgs else None # Pass through file paths and URLs for APIs that support them if prompt_cfg.get("file_paths"): params["file_paths"] = prompt_cfg["file_paths"] logger.debug(f"PromptManager: Passing file_paths to API (some APIs process videos/PDFs directly).") if prompt_cfg.get("urls"): params["urls"] = prompt_cfg["urls"] logger.debug(f"PromptManager: Passing URLs to API.") # Audio path remains separate if prompt_cfg.get("audio_path"): params["audio_path"] = prompt_cfg["audio_path"] else: params["images"] = None # --- New: If file_paths contain local video files, extract a few frames --- try: file_paths_value = params.get("file_paths") candidate_paths: List[str] = [] if isinstance(file_paths_value, str) and file_paths_value.strip(): candidate_paths = [file_paths_value.strip()] elif isinstance(file_paths_value, list): candidate_paths = [p for p in file_paths_value if isinstance(p, str) and p.strip()] extracted_from_files: List[str] = [] for p in candidate_paths: if _is_video_file(p) and os.path.isfile(p): extracted_from_files.extend(_extract_video_file_frames_as_b64(p)) if extracted_from_files: if params.get("images") is None: params["images"] = [] if not isinstance(params["images"], list): params["images"] = [params["images"]] # normalize # cap total images to 8 to avoid huge payloads remaining = max(0, 8 - len(params["images"])) params["images"].extend(extracted_from_files[:remaining]) logger.info("Added %d frame(s) extracted from video file paths to images payload.", min(len(extracted_from_files), remaining)) except Exception as e: logger.warning("Failed to process video file paths for frame extraction: %s", e, exc_info=True) log_params = _sanitize_params_for_log(params) logger.info(f"[Non-Streaming] Making LLM request with params: {log_params}") try: # --- CALL NON-STREAMING VERSION --- response_data = run_async( send_request( # Original non-streaming call llm_provider=params["llm_provider"], base_ip=params.get("base_ip", "localhost"), port=params.get("port", "11434"), images=params.get("images"), llm_model=params["llm_model"], system_message=params["system_message"], user_message=params["user_message"], messages=params["messages"], seed=params.get("seed"), temperature=params["temperature"], max_tokens=params["max_tokens"], random=params.get("random", False), top_k=params["top_k"], top_p=params["top_p"], repeat_penalty=params["repeat_penalty"], stop=params.get("stop"), keep_alive=params.get("keep_alive", True), llm_api_key=params.get("llm_api_key"), reasoning_format=params.get("reasoning_format") # Pass reasoning_format ) ) # --- END NON-STREAMING CALL --- except Exception as e: logger.error(f"Error in non-streaming send_request call: {e}", exc_info=True) response_data = {"choices": [{"message": {"content": f"Error calling send_request: {str(e)}"}}]} if response_data is None: content = "Error: Received None response" elif isinstance(response_data, dict): if "choices" in response_data and response_data["choices"]: message = response_data["choices"][0].get("message", {}) content = message.get("content", "") if content is None: content = "Error: Null content in response" elif "response" in response_data: content = response_data["response"] else: content = f"Error: Unexpected format: {str(response_data)}" elif isinstance(response_data, str): content = response_data else: content = f"Error: Unexpected type: {type(response_data)}" # Apply thinking tag removal if requested if hide_thinking and content: content = _remove_thinking_tags(content) if context is not None and isinstance(context, dict): context_out = context.copy() context_out["llm_response"] = content context_out["llm_raw_response"] = response_data else: context_out = {"llm_response": content, "llm_raw_response": response_data, "passthrough_data": context} payload = ContextPayload(content, context_out) return {"ui": {"string": [content]}, "result": (payload, str(payload))} except Exception as e: error_message = f"Error generating text: {str(e)}" logger.error(error_message, exc_info=True) error_output = {"error": error_message, "original_input": context} payload = ContextPayload(error_message, error_output) return {"ui": {"string": [error_message]}, "result": (payload, str(payload))} # --- Helper to avoid dumping large base64 blobs to INFO log ----------------- def _sanitize_params_for_log(d: dict) -> dict: """Return a shallow copy with huge fields replaced by short placeholders.""" out = {} for k, v in d.items(): # Hide large binary blobs (images, video frames) if k in {"images", "video_frames", "base64_images"} and v: # Replace with concise placeholder if isinstance(v, (list, tuple)): out[k] = f"[{len(v)} item(s) base64 omitted]" elif isinstance(v, str): out[k] = "[base64 string omitted]" else: out[k] = "[data omitted]" continue # Mask any values that look like API keys so we never log them in full key_lc = k.lower() if "api_key" in key_lc or key_lc.endswith("_key") or key_lc.endswith("key"): if isinstance(v, str) and v: # Keep first 5 chars for debugging, mask the rest masked = v[:5] + "…" if len(v) > 5 else "***" out[k] = masked else: out[k] = "[key hidden]" continue # Default passthrough else: out[k] = v return out # --- NEW STREAMING NODE --- class LLMToolkitTextGeneratorStream: DEFAULT_PROVIDER = "openai" DEFAULT_MODEL: str = "gpt-4o-mini" MODEL_LIST: List[str] = [DEFAULT_MODEL] @classmethod def INPUT_TYPES(cls): return { "required": { "prompt": ("STRING", {"multiline": False, "default": "Write a detailed description of a futuristic city."}), "hide_thinking": ("BOOLEAN", {"default": True, "tooltip": "Hide model thinking process (content between tags)"}) }, "optional": { "context": ("*", {}), "temperature": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 2.0, "step": 0.05}), "max_tokens": ("INT", {"default": 1024, "min": 1, "max": 65536}), "top_p": ("FLOAT", {"default": 0.9, "min": 0.0, "max": 1.0, "step": 0.05}), "top_k": ("INT", {"default": 40, "min": 0, "max": 500}), "seed": ("INT", {"default": -1, "min": -1, "max": 2147483647, "tooltip": "-1 = random seed"}) }, "hidden": { # <-- Add hidden inputs "unique_id": "UNIQUE_ID", "llm_model": ("STRING", {"default": cls.DEFAULT_MODEL}) }, } RETURN_TYPES = ("*", "STRING") RETURN_NAMES = ("context", "text") FUNCTION = "generate_stream" # <-- Use new function name CATEGORY = "🔗llm_toolkit/generators" OUTPUT_NODE = True # Keep the JS widget logic def generate_stream(self, prompt, hide_thinking, unique_id, llm_model=None, context=None, **kwargs): """ Generates text using the specified provider and streams the response back to the UI via websocket messages. (synchronous wrapper) """ # Wrap the previous async implementation inside an inner coroutine async def _async_generate(): # Previous async body START if PromptServer is None: logger.error("PromptServer not available. Cannot stream.") error_msg = "Streaming requires PromptServer, which is not available." error_output = {"error": error_msg, "original_input": context} payload = ContextPayload(error_msg, error_output) return {"ui": {"string": [error_msg]}, "result": (payload, str(payload))} server = PromptServer.instance full_response_text = "" thinking_buffer = "" inside_thinking = False try: # --- Corrected Parameter Processing Logic --- # 1. Start with node defaults params = { "llm_provider": self.DEFAULT_PROVIDER, "llm_model": llm_model or self.DEFAULT_MODEL, "system_message": "You are a helpful, creative, and concise assistant.", "user_message": prompt, "base_ip": "localhost", "port": "11434", "temperature": kwargs.get('temperature', 0.7), "max_tokens": kwargs.get('max_tokens', 1024), "top_p": kwargs.get('top_p', 0.9), "top_k": kwargs.get('top_k', 40), "seed": kwargs.get('seed', -1), "repeat_penalty": 1.1, "stop": None, "keep_alive": "5m", "messages": [], "llm_api_key": None, } # 2. Apply general settings from the incoming context if context and isinstance(context, dict): for key, value in context.items(): if key in params and value is not None: params[key] = value # 3. Apply specific provider_config settings, which take precedence provider_config = None if context and isinstance(context, dict): if "provider_config" in context and isinstance(context["provider_config"], dict): provider_config = context["provider_config"] elif "provider_name" in context: # Handle flat context from old nodes provider_config = context if provider_config: # First, update all matching keys from provider_config that also exist in params params.update({k: v for k, v in provider_config.items() if k in params and v is not None}) # Second, handle specific key name differences. This ensures provider_name is mapped correctly. if "provider_name" in provider_config and provider_config["provider_name"]: params["llm_provider"] = provider_config["provider_name"] elif context and isinstance(context, dict) and context.get("provider_name"): # Fallback to root-level key if nested provider_config didn't have it params["llm_provider"] = context["provider_name"] if "api_key" in provider_config and provider_config["api_key"]: params["llm_api_key"] = provider_config["api_key"] elif context and isinstance(context, dict) and context.get("api_key"): params["llm_api_key"] = context["api_key"] if "llm_model" in provider_config and provider_config["llm_model"]: params["llm_model"] = provider_config["llm_model"] elif context and isinstance(context, dict) and context.get("llm_model"): params["llm_model"] = context["llm_model"] # 4. The user_message from context or provider_config can be used, # but the node's prompt input has the final say. if context and isinstance(context, dict) and "user_prompt" in context and context["user_prompt"]: params["user_message"] = context.get("user_prompt") if provider_config and "user_prompt" in provider_config and provider_config["user_prompt"]: params["user_message"] = provider_config.get("user_prompt") if prompt: # Node input is final override, if provided params["user_message"] = prompt # --- End Corrected Parameter Processing --- # Finalize model name fallback provider = "unknown" if params.get("llm_provider"): provider = str(params["llm_provider"]).lower() if not params.get("llm_model"): PROVIDER_DEFAULTS = {"openai": "gpt-4o-mini", "anthropic": "claude-3-opus-20240229"} fallback = PROVIDER_DEFAULTS.get(provider) params["llm_model"] = fallback or self.DEFAULT_MODEL # Groq reasoning format if provider == "groq": params["reasoning_format"] = "hidden" if hide_thinking else "raw" # Auto-fetch API key for OpenAI if missing/placeholder if provider == "openai" and (not params.get("llm_api_key") or params["llm_api_key"] in {"", "1234", None}): try: params["llm_api_key"] = get_api_key("OPENAI_API_KEY", "openai") logger.info("generate_stream: Retrieved OpenAI API key via get_api_key helper.") except ValueError as _e: logger.warning(f"generate_stream: get_api_key failed – {_e}") # --- Prompt-Manager support --------------------------------------------------- # If upstream PromptManager attached a prompt_config dict, use the data prompt_cfg = None if context is not None and isinstance(context, dict): prompt_cfg = context.get("prompt_config") if prompt_cfg and isinstance(prompt_cfg, dict): # Override user prompt text if provided if prompt_cfg.get("text"): params["user_message"] = prompt_cfg["text"] # Collect images and video frames (already in base64 from PromptManager) imgs = [] if prompt_cfg.get("image_base64"): img_val = prompt_cfg["image_base64"] if isinstance(img_val, str): imgs.extend([img_val]) elif isinstance(img_val, list): imgs.extend(img_val) # Add extracted video frames if prompt_cfg.get("video_frames_base64"): vid_frames = prompt_cfg["video_frames_base64"] if isinstance(vid_frames, list): imgs.extend(vid_frames) else: imgs.append(vid_frames) params["images"] = imgs if imgs else None # Pass through file paths and URLs for APIs that support them if prompt_cfg.get("file_paths"): params["file_paths"] = prompt_cfg["file_paths"] logger.debug(f"PromptManager: Passing file_paths to API (some APIs process videos/PDFs directly).") if prompt_cfg.get("urls"): params["urls"] = prompt_cfg["urls"] logger.debug(f"PromptManager: Passing URLs to API.") # Audio path remains separate if prompt_cfg.get("audio_path"): params["audio_path"] = prompt_cfg["audio_path"] else: params["images"] = None # --- New: If file_paths contain local video files, extract a few frames --- try: file_paths_value = params.get("file_paths") candidate_paths: List[str] = [] if isinstance(file_paths_value, str) and file_paths_value.strip(): candidate_paths = [file_paths_value.strip()] elif isinstance(file_paths_value, list): candidate_paths = [p for p in file_paths_value if isinstance(p, str) and p.strip()] extracted_from_files: List[str] = [] for p in candidate_paths: if _is_video_file(p) and os.path.isfile(p): extracted_from_files.extend(_extract_video_file_frames_as_b64(p)) if extracted_from_files: if params.get("images") is None: params["images"] = [] if not isinstance(params["images"], list): params["images"] = [params["images"]] # normalize # cap total images to 8 to avoid huge payloads remaining = max(0, 8 - len(params["images"])) params["images"].extend(extracted_from_files[:remaining]) logger.info( "[Streaming] Added %d frame(s) extracted from video file paths to images payload.", min(len(extracted_from_files), remaining), ) except Exception as e: logger.warning( "[Streaming] Failed to process video file paths for frame extraction: %s", e, exc_info=True, ) log_params = _sanitize_params_for_log(params) logger.info( f"[Streaming] Initiating LLM stream with params: {log_params} for node {unique_id}" ) # --- Send START message --- server.send_sync("llmtoolkit.stream.start", {"node": unique_id}, sid=server.client_id) # --- Initiate and process the stream --- stream_generator = send_request_stream( llm_provider=params["llm_provider"], base_ip=params.get("base_ip", "localhost"), port=params.get("port", "11434"), llm_model=params["llm_model"], system_message=params["system_message"], user_message=params["user_message"], messages=params["messages"], seed=params.get("seed"), temperature=params["temperature"], max_tokens=params["max_tokens"], random=params.get("random", False), top_k=params["top_k"], top_p=params["top_p"], repeat_penalty=params["repeat_penalty"], stop=params.get("stop"), keep_alive=params.get("keep_alive", True), llm_api_key=params.get("llm_api_key"), base64_images=params.get("images"), ) async for chunk in stream_generator: if chunk: full_response_text += chunk # Handle thinking tag filtering for streaming if hide_thinking: # Track thinking state and buffer content chunk_to_send = "" i = 0 while i < len(chunk): if not inside_thinking: # Check for start of thinking tag if chunk[i:i+7] == '': inside_thinking = True thinking_buffer = '' i += 7 continue elif chunk[i:i+9] == '◁think▷': # <-- Add Kimi-VL start tag inside_thinking = True thinking_buffer = '◁think▷' i += 9 continue else: chunk_to_send += chunk[i] else: # Inside thinking block, buffer until we find closing tag thinking_buffer += chunk[i] if thinking_buffer.endswith(''): inside_thinking = False thinking_buffer = "" i += 1 continue elif thinking_buffer.endswith('◁/think▷'): # <-- Add Kimi-VL end tag inside_thinking = False thinking_buffer = "" i += 1 continue i += 1 # Only send non-thinking content if chunk_to_send: server.send_sync( "llmtoolkit.stream.chunk", {"node": unique_id, "text": chunk_to_send}, sid=server.client_id, ) else: # Send all content if not hiding thinking server.send_sync( "llmtoolkit.stream.chunk", {"node": unique_id, "text": chunk}, sid=server.client_id, ) await asyncio.sleep(0.001) # Apply thinking tag removal to final text if requested final_text = full_response_text if hide_thinking: final_text = _remove_thinking_tags(full_response_text) logger.info( f"[Streaming] Finished for node {unique_id}. Total length: {len(final_text)}" ) server.send_sync( "llmtoolkit.stream.end", {"node": unique_id, "final_text": final_text}, sid=server.client_id, ) # --- Prepare final context output --- if context is not None and isinstance(context, dict): context_out = context.copy() context_out["llm_response"] = final_text context_out["llm_raw_response"] = { "status": "Streamed successfully", "final_length": len(final_text), } else: context_out = { "llm_response": final_text, "llm_raw_response": { "status": "Streamed successfully", "final_length": len(final_text), }, "passthrough_data": context, } payload = ContextPayload(final_text, context_out) return {"ui": {"string": [final_text]}, "result": (payload, str(payload))} except Exception as e: error_message = f"Error during streaming generation: {str(e)}" logger.error(error_message, exc_info=True) if server and unique_id: server.send_sync( "llmtoolkit.stream.error", {"node": unique_id, "error": error_message}, sid=server.client_id, ) error_output = { "error": error_message, "partial_response": full_response_text, "original_input": context, } payload = ContextPayload(error_message, error_output) return { "ui": {"string": [f"Error: {error_message}\nPartial: {full_response_text}"]}, "result": (payload, str(payload)), } # Previous async body END # Execute the inner coroutine and return its result synchronously return run_async(_async_generate()) # Node Mappings for ComfyUI NODE_CLASS_MAPPINGS = { "LLMToolkitTextGenerator": LLMToolkitTextGenerator, # Keep original "LLMToolkitTextGeneratorStream": LLMToolkitTextGeneratorStream # Add streaming version } NODE_DISPLAY_NAME_MAPPINGS = { "LLMToolkitTextGenerator": "Generate Text (🔗LLMToolkit)", "LLMToolkitTextGeneratorStream": "Generate Text Stream (🔗LLMToolkit)" # New display name }