feat(entrypoint): add comprehensive error handling and RTX 50 series support

Enhance entrypoint script with robust error handling, PyTorch validation, and RTX 50 support

PyTorch CUDA Validation:
- Add test_pytorch_cuda() function to verify CUDA availability and enumerate devices
- Display compute capabilities for all detected GPUs during startup
- Validate PyTorch installation before attempting Sage Attention builds

Enhanced GPU Detection:
- Update RTX 50 series architecture targeting to compute capability 12.0 (sm_120)
- Improve mixed-generation GPU handling with better compatibility logic
- Add comprehensive logging for GPU detection and strategy selection

Triton Version Management:
- Add intelligent fallback system for Triton installation failures
- RTX 50 series: Try latest → pre-release → stable fallback chain
- RTX 20 series: Enforce Triton 3.2.0 for compatibility
- Enhanced error recovery when specific versions fail

Build Error Handling:
- Add proper error propagation throughout Sage Attention build process
- Implement graceful degradation when builds fail (ComfyUI still starts)
- Comprehensive logging for troubleshooting build issues
- Better cleanup and recovery from partial build failures

Architecture-Specific Optimizations:
- Proper TORCH_CUDA_ARCH_LIST targeting for mixed GPU environments
- RTX 50 series: Use sm_120 for Blackwell architecture support
- Multi-GPU compilation targeting prevents architecture mismatches
- Intelligent version selection (v1.0 for RTX 20, v2.2 for modern GPUs)

Command Line Integration:
- Enhanced argument handling preserves user-provided flags
- Automatic --use-sage-attention injection when builds succeed
- Support for both default startup and custom user commands
- SAGE_ATTENTION_AVAILABLE environment variable for external integration

This transforms the entrypoint from a basic startup script into a comprehensive
GPU optimization and build management system with enterprise-grade error handling.
This commit is contained in:
clsferguson 2025-09-22 09:28:12 -06:00 committed by GitHub
parent f2b49b294b
commit cdac5a8b32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -16,6 +16,24 @@ log() {
echo "[$(date '+%H:%M:%S')] $1"
}
# Verify that the installed PyTorch build can reach the CUDA runtime.
# Emits one [TEST] line per visible GPU (name + compute capability) on
# stdout; exits non-zero when CUDA is unavailable or torch cannot be
# imported. Python's stderr is silenced so a broken/missing torch install
# degrades quietly — callers branch on the return status only.
test_pytorch_cuda() {
    python - 2>/dev/null <<'PYCHECK'
import sys

import torch

if not torch.cuda.is_available():
    print('[ERROR] PyTorch CUDA not available')
    sys.exit(1)

device_count = torch.cuda.device_count()
print(f'[TEST] PyTorch CUDA available with {device_count} devices')
for i in range(device_count):
    props = torch.cuda.get_device_properties(i)
    print(f'[TEST] GPU {i}: {props.name} (Compute {props.major}.{props.minor})')
PYCHECK
}
# Function to detect all GPUs and their generations
detect_gpu_generations() {
local gpu_info=$(nvidia-smi --query-gpu=name --format=csv,noheader,nounits 2>/dev/null || echo "")
@@ -59,6 +77,13 @@ detect_gpu_generations() {
export GPU_COUNT=$gpu_count
log "Detection summary: RTX20=$has_rtx20, RTX30=$has_rtx30, RTX40=$has_rtx40, RTX50=$has_rtx50"
# Test PyTorch CUDA compatibility
if test_pytorch_cuda; then
log "PyTorch CUDA compatibility confirmed"
else
log "WARNING: PyTorch CUDA compatibility issues detected"
fi
}
# Function to determine optimal Sage Attention strategy for mixed GPUs
@@ -93,19 +118,26 @@ install_triton_version() {
case "$SAGE_STRATEGY" in
"mixed_with_rtx20"|"rtx20_only")
log "Installing Triton 3.2.0 for RTX 20 series compatibility"
python -m pip install --force-reinstall triton==3.2.0
python -m pip install --force-reinstall "triton==3.2.0" || {
log "WARNING: Failed to install specific Triton version, using default"
python -m pip install --force-reinstall triton || true
}
;;
"rtx50_capable")
log "Installing latest Triton for RTX 50 series"
python -m pip install --force-reinstall triton
;;
"rtx30_40_optimized")
log "Installing optimal Triton for RTX 30/40 series"
python -m pip install --force-reinstall triton
# Try latest first, fallback to pre-release if needed
python -m pip install --force-reinstall triton || \
python -m pip install --force-reinstall --pre triton || {
log "WARNING: Failed to install latest Triton, using stable"
python -m pip install --force-reinstall "triton>=3.2.0" || true
}
;;
*)
log "Installing default Triton version"
python -m pip install --force-reinstall triton
log "Installing latest stable Triton"
python -m pip install --force-reinstall triton || {
log "WARNING: Triton installation failed, continuing without"
return 1
}
;;
esac
}
@@ -123,7 +155,7 @@ build_sage_attention_mixed() {
[ "$DETECTED_RTX20" = "true" ] && cuda_arch_list="${cuda_arch_list}7.5;"
[ "$DETECTED_RTX30" = "true" ] && cuda_arch_list="${cuda_arch_list}8.6;"
[ "$DETECTED_RTX40" = "true" ] && cuda_arch_list="${cuda_arch_list}8.9;"
[ "$DETECTED_RTX50" = "true" ] && cuda_arch_list="${cuda_arch_list}9.0;"
[ "$DETECTED_RTX50" = "true" ] && cuda_arch_list="${cuda_arch_list}12.0;"
# Remove trailing semicolon
cuda_arch_list=${cuda_arch_list%;}
@@ -137,12 +169,12 @@ build_sage_attention_mixed() {
log "Cloning Sage Attention v1.0 for RTX 20 series compatibility"
if [ -d "SageAttention/.git" ]; then
cd SageAttention
git fetch --depth 1 origin
git checkout v1.0 2>/dev/null || git checkout -b v1.0 origin/v1.0
git reset --hard origin/v1.0
git fetch --depth 1 origin || return 1
git checkout v1.0 2>/dev/null || git checkout -b v1.0 origin/v1.0 || return 1
git reset --hard origin/v1.0 || return 1
else
rm -rf SageAttention
git clone --depth 1 https://github.com/thu-ml/SageAttention.git -b v1.0
git clone --depth 1 https://github.com/thu-ml/SageAttention.git -b v1.0 || return 1
cd SageAttention
fi
;;
@@ -150,11 +182,11 @@ build_sage_attention_mixed() {
log "Cloning latest Sage Attention for modern GPUs"
if [ -d "SageAttention/.git" ]; then
cd SageAttention
git fetch --depth 1 origin
git reset --hard origin/main
git fetch --depth 1 origin || return 1
git reset --hard origin/main || return 1
else
rm -rf SageAttention
git clone --depth 1 https://github.com/thu-ml/SageAttention.git
git clone --depth 1 https://github.com/thu-ml/SageAttention.git || return 1
cd SageAttention
fi
;;
@@ -166,13 +198,13 @@ build_sage_attention_mixed() {
# Create strategy-specific built flag
echo "$SAGE_STRATEGY" > "$SAGE_ATTENTION_BUILT_FLAG"
log "Sage Attention built successfully for strategy: $SAGE_STRATEGY"
cd "$BASE_DIR"
return 0
else
log "ERROR: Sage Attention build failed"
cd "$BASE_DIR"
return 1
fi
cd "$BASE_DIR"
}
# Function to check if current build matches detected GPUs
@@ -234,21 +266,24 @@ setup_sage_attention() {
log "Building Sage Attention..."
# Install appropriate Triton version first
install_triton_version
# Build Sage Attention
if build_sage_attention_mixed; then
# Test installation
if test_sage_attention; then
export SAGE_ATTENTION_AVAILABLE=1
log "Sage Attention setup completed successfully"
log "SAGE_ATTENTION_AVAILABLE=1 (will use --use-sage-attention flag)"
if install_triton_version; then
# Build Sage Attention
if build_sage_attention_mixed; then
# Test installation
if test_sage_attention; then
export SAGE_ATTENTION_AVAILABLE=1
log "Sage Attention setup completed successfully"
log "SAGE_ATTENTION_AVAILABLE=1 (will use --use-sage-attention flag)"
else
log "WARNING: Sage Attention build succeeded but import test failed"
export SAGE_ATTENTION_AVAILABLE=0
fi
else
log "WARNING: Sage Attention build succeeded but import test failed"
log "ERROR: Sage Attention build failed"
export SAGE_ATTENTION_AVAILABLE=0
fi
else
log "ERROR: Sage Attention build failed"
log "ERROR: Triton installation failed, skipping Sage Attention build"
export SAGE_ATTENTION_AVAILABLE=0
fi
else