diff --git a/entrypoint.sh b/entrypoint.sh index e60a4aaa2..0f92c7dea 100644 --- a/entrypoint.sh +++ b/entrypoint.sh @@ -16,6 +16,24 @@ log() { echo "[$(date '+%H:%M:%S')] $1" } +# Function to test PyTorch CUDA compatibility +test_pytorch_cuda() { + python -c " +import torch +import sys +if not torch.cuda.is_available(): + print('[ERROR] PyTorch CUDA not available') + sys.exit(1) + +device_count = torch.cuda.device_count() +print(f'[TEST] PyTorch CUDA available with {device_count} devices') + +for i in range(device_count): + props = torch.cuda.get_device_properties(i) + print(f'[TEST] GPU {i}: {props.name} (Compute {props.major}.{props.minor})') +" 2>/dev/null +} + # Function to detect all GPUs and their generations detect_gpu_generations() { local gpu_info=$(nvidia-smi --query-gpu=name --format=csv,noheader,nounits 2>/dev/null || echo "") @@ -59,6 +77,13 @@ detect_gpu_generations() { export GPU_COUNT=$gpu_count log "Detection summary: RTX20=$has_rtx20, RTX30=$has_rtx30, RTX40=$has_rtx40, RTX50=$has_rtx50" + + # Test PyTorch CUDA compatibility + if test_pytorch_cuda; then + log "PyTorch CUDA compatibility confirmed" + else + log "WARNING: PyTorch CUDA compatibility issues detected" + fi } # Function to determine optimal Sage Attention strategy for mixed GPUs @@ -93,19 +118,26 @@ install_triton_version() { case "$SAGE_STRATEGY" in "mixed_with_rtx20"|"rtx20_only") log "Installing Triton 3.2.0 for RTX 20 series compatibility" - python -m pip install --force-reinstall triton==3.2.0 + python -m pip install --force-reinstall "triton==3.2.0" || { + log "WARNING: Failed to install specific Triton version, using default" + python -m pip install --force-reinstall triton || true + } ;; "rtx50_capable") log "Installing latest Triton for RTX 50 series" - python -m pip install --force-reinstall triton - ;; - "rtx30_40_optimized") - log "Installing optimal Triton for RTX 30/40 series" - python -m pip install --force-reinstall triton + # Try latest first, fallback to pre-release if needed + python -m pip install --force-reinstall triton || \ + python -m pip install --force-reinstall --pre triton || { + log "WARNING: Failed to install latest Triton, using stable" + python -m pip install --force-reinstall "triton>=3.2.0" || true + } ;; *) - log "Installing default Triton version" - python -m pip install --force-reinstall triton + log "Installing latest stable Triton" + python -m pip install --force-reinstall triton || { + log "WARNING: Triton installation failed, continuing without" + return 1 + } ;; esac } @@ -123,7 +155,7 @@ build_sage_attention_mixed() { [ "$DETECTED_RTX20" = "true" ] && cuda_arch_list="${cuda_arch_list}7.5;" [ "$DETECTED_RTX30" = "true" ] && cuda_arch_list="${cuda_arch_list}8.6;" [ "$DETECTED_RTX40" = "true" ] && cuda_arch_list="${cuda_arch_list}8.9;" - [ "$DETECTED_RTX50" = "true" ] && cuda_arch_list="${cuda_arch_list}9.0;" + [ "$DETECTED_RTX50" = "true" ] && cuda_arch_list="${cuda_arch_list}12.0;" # Remove trailing semicolon cuda_arch_list=${cuda_arch_list%;} @@ -137,12 +169,12 @@ build_sage_attention_mixed() { log "Cloning Sage Attention v1.0 for RTX 20 series compatibility" if [ -d "SageAttention/.git" ]; then cd SageAttention - git fetch --depth 1 origin - git checkout v1.0 2>/dev/null || git checkout -b v1.0 origin/v1.0 - git reset --hard origin/v1.0 + git fetch --depth 1 origin || return 1 + git checkout v1.0 2>/dev/null || git checkout -b v1.0 origin/v1.0 || return 1 + git reset --hard origin/v1.0 || return 1 else rm -rf SageAttention - git clone --depth 1 https://github.com/thu-ml/SageAttention.git -b v1.0 + git clone --depth 1 https://github.com/thu-ml/SageAttention.git -b v1.0 || return 1 cd SageAttention fi ;; @@ -150,11 +182,11 @@ build_sage_attention_mixed() { log "Cloning latest Sage Attention for modern GPUs" if [ -d "SageAttention/.git" ]; then cd SageAttention - git fetch --depth 1 origin - git reset --hard origin/main + git fetch --depth 1 origin || return 1 + git reset --hard origin/main || return 1 else rm -rf SageAttention - git clone --depth 1 https://github.com/thu-ml/SageAttention.git + git clone --depth 1 https://github.com/thu-ml/SageAttention.git || return 1 cd SageAttention fi ;; @@ -166,13 +198,13 @@ build_sage_attention_mixed() { # Create strategy-specific built flag echo "$SAGE_STRATEGY" > "$SAGE_ATTENTION_BUILT_FLAG" log "Sage Attention built successfully for strategy: $SAGE_STRATEGY" + cd "$BASE_DIR" return 0 else log "ERROR: Sage Attention build failed" + cd "$BASE_DIR" return 1 fi - - cd "$BASE_DIR" } # Function to check if current build matches detected GPUs @@ -234,21 +266,24 @@ setup_sage_attention() { log "Building Sage Attention..." # Install appropriate Triton version first - install_triton_version - - # Build Sage Attention - if build_sage_attention_mixed; then - # Test installation - if test_sage_attention; then - export SAGE_ATTENTION_AVAILABLE=1 - log "Sage Attention setup completed successfully" - log "SAGE_ATTENTION_AVAILABLE=1 (will use --use-sage-attention flag)" + if install_triton_version; then + # Build Sage Attention + if build_sage_attention_mixed; then + # Test installation + if test_sage_attention; then + export SAGE_ATTENTION_AVAILABLE=1 + log "Sage Attention setup completed successfully" + log "SAGE_ATTENTION_AVAILABLE=1 (will use --use-sage-attention flag)" + else + log "WARNING: Sage Attention build succeeded but import test failed" + export SAGE_ATTENTION_AVAILABLE=0 + fi else - log "WARNING: Sage Attention build succeeded but import test failed" + log "ERROR: Sage Attention build failed" export SAGE_ATTENTION_AVAILABLE=0 fi else - log "ERROR: Sage Attention build failed" + log "ERROR: Triton installation failed, skipping Sage Attention build" export SAGE_ATTENTION_AVAILABLE=0 fi else