Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-01-29 07:40:21 +08:00)

Compare commits: b95b132230...fe725e24d5 (16 commits)
Commits in this range (newest first):

fe725e24d5
1a20656448
0f11869d55
5943fbf457
a60b7b86c5
2e9d51680a
50d6e1caf4
ac12f77bed
0df9e96683
3cfe58d0c3
8c0f498a23
1a410446e3
494dce9a36
ddad64a4bf
c06b18a014
e86ffb0ea6
.github/workflows/test-ci.yml (vendored; 2 deletions)

@@ -20,7 +20,6 @@ jobs:
   test-stable:
     strategy:
       fail-fast: false
-      max-parallel: 1 # This forces sequential execution
       matrix:
         # os: [macos, linux, windows]
         # os: [macos, linux]
@@ -75,7 +74,6 @@ jobs:
   test-unix-nightly:
     strategy:
       fail-fast: false
-      max-parallel: 1 # This forces sequential execution
       matrix:
         # os: [macos, linux]
         os: [linux]
app/logger.py

@@ -11,6 +11,10 @@ stderr_interceptor = None


 class LogInterceptor(io.TextIOWrapper):
+    # Maximum logs to buffer between flushes to prevent unbounded memory growth
+    # if callbacks persistently fail. 10000 entries is ~2-5MB depending on message size.
+    MAX_PENDING_LOGS = 10000
+
     def __init__(self, stream, *args, **kwargs):
         buffer = stream.buffer
         encoding = stream.encoding
@@ -23,6 +27,9 @@ class LogInterceptor(io.TextIOWrapper):
         entry = {"t": datetime.now().isoformat(), "m": data}
         with self._lock:
             self._logs_since_flush.append(entry)
+            # Enforce max size to prevent OOM if callbacks persistently fail
+            if len(self._logs_since_flush) > self.MAX_PENDING_LOGS:
+                self._logs_since_flush = self._logs_since_flush[-self.MAX_PENDING_LOGS:]

         # Simple handling for cr to overwrite the last output if it isnt a full line
         # else logs just get full of progress messages
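The cap above effectively turns the pending-log buffer into a bounded list. A standalone sketch of the trimming rule, using a tiny hypothetical cap of 3 rather than the real 10000:

    # Once the list exceeds the cap, only the newest entries survive.
    MAX_PENDING_LOGS = 3
    pending = [{"m": "msg0"}, {"m": "msg1"}, {"m": "msg2"}]
    pending.append({"m": "msg3"})              # one write past the cap
    if len(pending) > MAX_PENDING_LOGS:
        pending = pending[-MAX_PENDING_LOGS:]  # drop the oldest entries
    assert [e["m"] for e in pending] == ["msg1", "msg2", "msg3"]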
@@ -32,10 +39,21 @@ class LogInterceptor(io.TextIOWrapper):
             super().write(data)

     def flush(self):
-        super().flush()
+        try:
+            super().flush()
+        except OSError as e:
+            # errno 22 (EINVAL) can occur on Windows with piped/redirected streams
+            # This is safe to ignore as write() already succeeded
+            if e.errno != 22:
+                raise
+        if not self._logs_since_flush:
+            return
+        # Copy to prevent callback mutations from affecting retry on failure
+        logs_to_send = list(self._logs_since_flush)
         for cb in self._flush_callbacks:
-            cb(self._logs_since_flush)
-        self._logs_since_flush = []
+            cb(logs_to_send)
+        # Only clear after all callbacks succeed - if any raises, logs remain for retry
+        self._logs_since_flush = []

     def on_flush(self, callback):
         self._flush_callbacks.append(callback)
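Together these changes give callbacks effectively at-least-once delivery: a failing callback leaves the buffered entries in place, so the next flush() re-sends them. A minimal sketch of that behavior; the MockStream stub mirrors the one in the new unit tests below, and flaky_sink is hypothetical:

    import io
    from app.logger import LogInterceptor

    class MockStream:
        def __init__(self):
            self.buffer = io.BytesIO()
            self.encoding = "utf-8"
            self.line_buffering = False

    interceptor = LogInterceptor(MockStream())
    received, state = [], {"fail": True}

    def flaky_sink(logs):
        if state["fail"]:                     # fail exactly once
            state["fail"] = False
            raise ConnectionError("sink temporarily down")
        received.extend(logs)

    interceptor.on_flush(flaky_sink)
    interceptor._logs_since_flush = [{"t": "t0", "m": "hello"}]
    try:
        interceptor.flush()                   # raises; entries stay buffered
    except ConnectionError:
        pass
    interceptor.flush()                       # retry delivers the entry
    assert [e["m"] for e in received] == ["hello"]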
@@ -3,8 +3,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 from comfy.ldm.modules.diffusionmodules.model import ResnetBlock, VideoConv3d
 from comfy.ldm.hunyuan_video.vae_refiner import RMS_norm
-import model_management
-import model_patcher
+import comfy.model_management
+import comfy.model_patcher

 class SRResidualCausalBlock3D(nn.Module):
     def __init__(self, channels: int):
@@ -103,13 +103,13 @@ UPSAMPLERS = {

 class HunyuanVideo15SRModel():
     def __init__(self, model_type, config):
-        self.load_device = model_management.vae_device()
-        offload_device = model_management.vae_offload_device()
-        self.dtype = model_management.vae_dtype(self.load_device)
+        self.load_device = comfy.model_management.vae_device()
+        offload_device = comfy.model_management.vae_offload_device()
+        self.dtype = comfy.model_management.vae_dtype(self.load_device)
         self.model_class = UPSAMPLERS.get(model_type)
         self.model = self.model_class(**config).eval()

-        self.patcher = model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
+        self.patcher = comfy.model_patcher.ModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)

     def load_sd(self, sd):
         return self.model.load_state_dict(sd, strict=True)
@@ -118,5 +118,5 @@ class HunyuanVideo15SRModel():
         return self.model.state_dict()

     def resample_latent(self, latent):
-        model_management.load_model_gpu(self.patcher)
+        comfy.model_management.load_model_gpu(self.patcher)
         return self.model(latent.to(self.load_device))
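The only change in this file is qualifying the imports. These modules live inside the comfy package, so the bare forms resolve only if the comfy/ directory itself happens to be on sys.path; the package-qualified forms work wherever ComfyUI is importable. Illustrative contrast (not part of the diff):

    import comfy.model_management    # resolves under a normal ComfyUI start
    # import model_management       # ModuleNotFoundError unless comfy/ is on sys.path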
comfy/model_management.py

@@ -22,7 +22,6 @@ from enum import Enum
 from comfy.cli_args import args, PerformanceFeature
 import torch
 import sys
-import importlib
 import platform
 import weakref
 import gc
@@ -349,10 +348,22 @@ try:
 except:
     rocm_version = (6, -1)

+def aotriton_supported(gpu_arch):
+    path = torch.__path__[0]
+    path = os.path.join(os.path.join(path, "lib"), "aotriton.images")
+    gfx = set(map(lambda a: a[4:], filter(lambda a: a.startswith("amd-gfx"), os.listdir(path))))
+    if gpu_arch in gfx:
+        return True
+    if "{}x".format(gpu_arch[:-1]) in gfx:
+        return True
+    if "{}xx".format(gpu_arch[:-2]) in gfx:
+        return True
+    return False
+
 logging.info("AMD arch: {}".format(arch))
 logging.info("ROCm version: {}".format(rocm_version))
 if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
-    if importlib.util.find_spec('triton') is not None: # AMD efficient attention implementation depends on triton. TODO: better way of detecting if it's compiled in or not.
+    if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
         if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
             if any((a in arch) for a in ["gfx90a", "gfx942", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
                 ENABLE_PYTORCH_ATTENTION = True
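aotriton_supported() probes the kernel images that PyTorch's ROCm builds ship under torch/lib/aotriton.images, whose subdirectories are named like amd-gfx90a or amd-gfx110x; stripping the amd- prefix leaves entries such as gfx90a and gfx110x. An exact architecture match is tried first, then one- and two-character wildcard families. A standalone sketch of the lookup against a hypothetical directory listing:

    gfx = {"gfx90a", "gfx110x", "gfx11xx"}   # hypothetical aotriton.images contents

    def matches(gpu_arch):
        # same three-step lookup as aotriton_supported() above
        return (gpu_arch in gfx
                or "{}x".format(gpu_arch[:-1]) in gfx
                or "{}xx".format(gpu_arch[:-2]) in gfx)

    assert matches("gfx90a")      # exact architecture image
    assert matches("gfx1101")     # covered by the "gfx110x" family image
    assert not matches("gfx906")  # no matching image shipped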
@@ -479,8 +479,8 @@ class VAE:
         self.first_stage_model = comfy.ldm.lightricks.vae.causal_video_autoencoder.VideoVAE(version=version, config=vae_config)
         self.latent_channels = 128
         self.latent_dim = 3
-        self.memory_used_decode = lambda shape, dtype: (900 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
-        self.memory_used_encode = lambda shape, dtype: (70 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
+        self.memory_used_decode = lambda shape, dtype: (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * model_management.dtype_size(dtype)
+        self.memory_used_encode = lambda shape, dtype: (80 * max(shape[2], 7) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
         self.upscale_ratio = (lambda a: max(0, a * 8 - 7), 32, 32)
         self.upscale_index_formula = (8, 32, 32)
         self.downscale_ratio = (lambda a: max(0, math.floor((a + 7) / 8)), 32, 32)
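The only change here is the pair of constants: the reserve estimate for decoding grows from a 900x to a 1200x multiplier per latent element (and from 70x to 80x for encoding). A quick worked example of the new decode formula, with an illustrative latent shape and fp16 (2 bytes per element):

    def memory_used_decode(shape, dtype_size):
        # new estimate: 1200 * t * h * w * (8*8*8 upscale) * bytes per element
        return (1200 * shape[2] * shape[3] * shape[4] * (8 * 8 * 8)) * dtype_size

    est = memory_used_decode([1, 128, 3, 16, 16], 2)
    print(est / 1024**3)   # ~0.88 GiB, up from ~0.66 GiB with the old 900 factor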
@@ -1,3 +1,3 @@
 # This file is automatically generated by the build process when version is
 # updated in pyproject.toml.
-__version__ = "0.8.0"
+__version__ = "0.8.2"
@@ -1 +1 @@
-comfyui_manager==4.0.4
+comfyui_manager==4.0.5
@@ -1,6 +1,6 @@
 [project]
 name = "ComfyUI"
-version = "0.8.0"
+version = "0.8.2"
 readme = "README.md"
 license = { file = "LICENSE" }
 requires-python = ">=3.10"
tests-unit/app_test/test_logger.py (new file; 321 additions)

@@ -0,0 +1,321 @@
"""Tests for the logger module, specifically LogInterceptor."""

import io
import pytest
from unittest.mock import MagicMock


class TestLogInterceptorFlush:
    """Test that LogInterceptor.flush() handles OSError gracefully."""

    def test_flush_handles_errno_22(self):
        """Test that flush() catches OSError with errno 22 and still executes callbacks."""
        # We can't easily mock the parent flush, so we test the behavior by
        # creating a LogInterceptor and verifying the flush method exists
        # with the try-except structure.

        # Read the source to verify the fix is in place
        import inspect
        from app.logger import LogInterceptor

        source = inspect.getsource(LogInterceptor.flush)

        # Verify the try-except structure is present
        assert 'try:' in source
        assert 'super().flush()' in source
        assert 'except OSError as e:' in source
        assert 'e.errno != 22' in source or 'e.errno == 22' in source

    def test_flush_callback_execution(self):
        """Test that flush callbacks are executed."""
        from app.logger import LogInterceptor

        # Create a proper stream for LogInterceptor
        import sys

        # Use a StringIO-based approach with a real buffer
        class MockStream:
            def __init__(self):
                self._buffer = io.BytesIO()
                self.encoding = 'utf-8'
                self.line_buffering = False

            @property
            def buffer(self):
                return self._buffer

        mock_stream = MockStream()
        interceptor = LogInterceptor(mock_stream)

        # Register a callback
        callback_results = []
        interceptor.on_flush(lambda logs: callback_results.append(len(logs)))

        # Add some logs
        interceptor._logs_since_flush = [
            {"t": "test", "m": "message1"},
            {"t": "test", "m": "message2"}
        ]

        # Flush should execute callback
        interceptor.flush()

        assert len(callback_results) == 1
        assert callback_results[0] == 2  # Two log entries

    def test_flush_clears_logs_after_callback(self):
        """Test that logs are cleared after flush callbacks."""
        from app.logger import LogInterceptor

        class MockStream:
            def __init__(self):
                self._buffer = io.BytesIO()
                self.encoding = 'utf-8'
                self.line_buffering = False

            @property
            def buffer(self):
                return self._buffer

        mock_stream = MockStream()
        interceptor = LogInterceptor(mock_stream)

        # Add a dummy callback
        interceptor.on_flush(lambda logs: None)

        # Add some logs
        interceptor._logs_since_flush = [{"t": "test", "m": "message"}]

        # Flush
        interceptor.flush()

        # Logs should be cleared
        assert interceptor._logs_since_flush == []

    def test_flush_multiple_callbacks_receive_same_logs(self):
        """Test that all callbacks receive the same logs, not just the first one."""
        from app.logger import LogInterceptor

        class MockStream:
            def __init__(self):
                self._buffer = io.BytesIO()
                self.encoding = 'utf-8'
                self.line_buffering = False

            @property
            def buffer(self):
                return self._buffer

        mock_stream = MockStream()
        interceptor = LogInterceptor(mock_stream)

        # Register multiple callbacks
        callback1_results = []
        callback2_results = []
        callback3_results = []
        interceptor.on_flush(lambda logs: callback1_results.append(len(logs)))
        interceptor.on_flush(lambda logs: callback2_results.append(len(logs)))
        interceptor.on_flush(lambda logs: callback3_results.append(len(logs)))

        # Add some logs
        interceptor._logs_since_flush = [
            {"t": "test", "m": "message1"},
            {"t": "test", "m": "message2"},
            {"t": "test", "m": "message3"}
        ]

        # Flush should execute all callbacks with the same logs
        interceptor.flush()

        # All callbacks should have received 3 log entries
        assert callback1_results == [3]
        assert callback2_results == [3]
        assert callback3_results == [3]

    def test_flush_preserves_logs_when_callback_raises(self):
        """Test that logs are preserved for retry if a callback raises an exception."""
        from app.logger import LogInterceptor

        class MockStream:
            def __init__(self):
                self._buffer = io.BytesIO()
                self.encoding = 'utf-8'
                self.line_buffering = False

            @property
            def buffer(self):
                return self._buffer

        mock_stream = MockStream()
        interceptor = LogInterceptor(mock_stream)

        # Register a callback that raises
        def raising_callback(logs):
            raise ValueError("Callback error")

        interceptor.on_flush(raising_callback)

        # Add some logs
        original_logs = [
            {"t": "test", "m": "message1"},
            {"t": "test", "m": "message2"}
        ]
        interceptor._logs_since_flush = original_logs.copy()

        # Flush should raise
        with pytest.raises(ValueError, match="Callback error"):
            interceptor.flush()

        # Logs should be preserved for retry on next flush
        assert interceptor._logs_since_flush == original_logs

    def test_flush_protects_logs_from_callback_mutation(self):
        """Test that callback mutations don't affect preserved logs on failure."""
        from app.logger import LogInterceptor

        class MockStream:
            def __init__(self):
                self._buffer = io.BytesIO()
                self.encoding = 'utf-8'
                self.line_buffering = False

            @property
            def buffer(self):
                return self._buffer

        mock_stream = MockStream()
        interceptor = LogInterceptor(mock_stream)

        # First callback mutates the list, second raises
        def mutating_callback(logs):
            logs.clear()  # Mutate the passed list

        def raising_callback(logs):
            raise ValueError("Callback error")

        interceptor.on_flush(mutating_callback)
        interceptor.on_flush(raising_callback)

        # Add some logs
        original_logs = [
            {"t": "test", "m": "message1"},
            {"t": "test", "m": "message2"}
        ]
        interceptor._logs_since_flush = original_logs.copy()

        # Flush should raise
        with pytest.raises(ValueError, match="Callback error"):
            interceptor.flush()

        # Logs should be preserved despite mutation by first callback
        assert interceptor._logs_since_flush == original_logs

    def test_flush_clears_logs_after_all_callbacks_succeed(self):
        """Test that logs are cleared only after all callbacks execute successfully."""
        from app.logger import LogInterceptor

        class MockStream:
            def __init__(self):
                self._buffer = io.BytesIO()
                self.encoding = 'utf-8'
                self.line_buffering = False

            @property
            def buffer(self):
                return self._buffer

        mock_stream = MockStream()
        interceptor = LogInterceptor(mock_stream)

        # Register multiple callbacks
        callback1_results = []
        callback2_results = []
        interceptor.on_flush(lambda logs: callback1_results.append(len(logs)))
        interceptor.on_flush(lambda logs: callback2_results.append(len(logs)))

        # Add some logs
        interceptor._logs_since_flush = [
            {"t": "test", "m": "message1"},
            {"t": "test", "m": "message2"}
        ]

        # Flush should succeed
        interceptor.flush()

        # All callbacks should have executed
        assert callback1_results == [2]
        assert callback2_results == [2]

        # Logs should be cleared after success
        assert interceptor._logs_since_flush == []


class TestLogInterceptorWrite:
    """Test that LogInterceptor.write() works correctly."""

    def test_write_adds_to_logs(self):
        """Test that write() adds entries to the log buffer."""
        from app.logger import LogInterceptor

        class MockStream:
            def __init__(self):
                self._buffer = io.BytesIO()
                self.encoding = 'utf-8'
                self.line_buffering = False

            @property
            def buffer(self):
                return self._buffer

        mock_stream = MockStream()
        interceptor = LogInterceptor(mock_stream)

        # Initialize the global logs
        import app.logger
        from collections import deque
        app.logger.logs = deque(maxlen=100)

        # Write a message
        interceptor.write("test message")

        # Check that it was added to _logs_since_flush
        assert len(interceptor._logs_since_flush) == 1
        assert interceptor._logs_since_flush[0]["m"] == "test message"

    def test_write_enforces_max_pending_logs(self):
        """Test that write() enforces MAX_PENDING_LOGS to prevent OOM."""
        from app.logger import LogInterceptor

        class MockStream:
            def __init__(self):
                self._buffer = io.BytesIO()
                self.encoding = 'utf-8'
                self.line_buffering = False

            @property
            def buffer(self):
                return self._buffer

        mock_stream = MockStream()
        interceptor = LogInterceptor(mock_stream)

        # Initialize the global logs
        import app.logger
        from collections import deque
        app.logger.logs = deque(maxlen=100)

        # Manually set _logs_since_flush to be at the limit
        interceptor._logs_since_flush = [
            {"t": "test", "m": f"old_message_{i}"}
            for i in range(LogInterceptor.MAX_PENDING_LOGS)
        ]

        # Write one more message - should trigger trimming
        interceptor.write("new_message")

        # Should still be at MAX_PENDING_LOGS, oldest dropped
        assert len(interceptor._logs_since_flush) == LogInterceptor.MAX_PENDING_LOGS
        # The new message should be at the end
        assert interceptor._logs_since_flush[-1]["m"] == "new_message"
        # The oldest message should have been dropped (old_message_0)
        assert interceptor._logs_since_flush[0]["m"] == "old_message_1"
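Assuming the repository's usual unit-test setup (the same one the test-ci.yml workflow above drives), this file should be runnable on its own with `pytest tests-unit/app_test/test_logger.py`.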