Merge branch 'master' into feat/api-nodes/api-client-v2

This commit is contained in:
bigcat88 2025-10-17 16:45:36 +03:00
commit 50f6a5e10d
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
4 changed files with 49 additions and 5 deletions

View File

@ -52,6 +52,16 @@ try:
except (ModuleNotFoundError, TypeError):
logging.warning("Could not set sdpa backend priority.")
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = False
try:
if comfy.model_management.is_nvidia():
if torch.backends.cudnn.version() >= 91200 and comfy.model_management.torch_version_numeric >= (2, 9) and comfy.model_management.torch_version_numeric <= (2, 10):
#TODO: change upper bound version once it's fixed'
NVIDIA_MEMORY_CONV_BUG_WORKAROUND = True
logging.info("working around nvidia conv3d memory bug.")
except:
pass
cast_to = comfy.model_management.cast_to #TODO: remove once no more references
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
@ -151,6 +161,15 @@ class disable_weight_init:
def reset_parameters(self):
return None
def _conv_forward(self, input, weight, bias, *args, **kwargs):
if NVIDIA_MEMORY_CONV_BUG_WORKAROUND and weight.dtype in (torch.float16, torch.bfloat16):
out = torch.cudnn_convolution(input, weight, self.padding, self.stride, self.dilation, self.groups, benchmark=False, deterministic=False, allow_tf32=True)
if bias is not None:
out += bias.reshape((1, -1) + (1,) * (out.ndim - 2))
return out
else:
return super()._conv_forward(input, weight, bias, *args, **kwargs)
def forward_comfy_cast_weights(self, input):
weight, bias = cast_bias_weight(self, input)
return self._conv_forward(input, weight, bias)

View File

@ -390,7 +390,9 @@ class Veo3VideoGenerationNode(VeoVideoGenerationNode):
),
IO.Combo.Input(
"model",
options=list(MODELS_MAP.keys()),
options=[
"veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.0-generate-001", "veo-3.0-fast-generate-001"
],
default="veo-3.0-generate-001",
tooltip="Veo 3 model to use for video generation",
optional=True,

View File

@ -244,6 +244,8 @@ class EasyCacheHolder:
self.total_steps_skipped += 1
batch_offset = x.shape[0] // len(uuids)
for i, uuid in enumerate(uuids):
# slice out only what is relevant to this cond
batch_slice = [slice(i*batch_offset,(i+1)*batch_offset)]
# if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
if not self.allow_mismatch:
@ -261,9 +263,8 @@ class EasyCacheHolder:
slicing.append(slice(None, dim_u))
else:
slicing.append(slice(None))
slicing = [slice(i*batch_offset,(i+1)*batch_offset)] + slicing
x = x[slicing]
x += self.uuid_cache_diffs[uuid].to(x.device)
batch_slice = batch_slice + slicing
x[batch_slice] += self.uuid_cache_diffs[uuid].to(x.device)
return x
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):

View File

@ -48,6 +48,28 @@ async def send_socket_catch_exception(function, message):
except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err:
logging.warning("send error: {}".format(err))
# Track deprecated paths that have been warned about to only warn once per file
_deprecated_paths_warned = set()
@web.middleware
async def deprecation_warning(request: web.Request, handler):
"""Middleware to warn about deprecated frontend API paths"""
path = request.path
if (path.startswith('/scripts/') or path.startswith('/extensions/core/')):
# Only warn once per unique file path
if path not in _deprecated_paths_warned:
_deprecated_paths_warned.add(path)
logging.warning(
f"[DEPRECATION WARNING] Detected import of deprecated legacy API: {path}. "
f"This is likely caused by a custom node extension using outdated APIs. "
f"Please update your extensions or contact the extension author for an updated version."
)
response: web.Response = await handler(request)
return response
@web.middleware
async def compress_body(request: web.Request, handler):
accept_encoding = request.headers.get("Accept-Encoding", "")
@ -159,7 +181,7 @@ class PromptServer():
self.client_session:Optional[aiohttp.ClientSession] = None
self.number = 0
middlewares = [cache_control]
middlewares = [cache_control, deprecation_warning]
if args.enable_compress_response_body:
middlewares.append(compress_body)