diff --git a/.github/workflows/windows_release_dependencies.yml b/.github/workflows/windows_release_dependencies.yml index 6428891ee..85e6a52fd 100644 --- a/.github/workflows/windows_release_dependencies.yml +++ b/.github/workflows/windows_release_dependencies.yml @@ -12,7 +12,7 @@ on: description: 'extra dependencies' required: false type: string - default: "\"numpy<2\"" + default: "" cu: description: 'cuda version' required: true diff --git a/comfy/model_management.py b/comfy/model_management.py index 93fcbb641..5b7dd0981 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -854,27 +854,21 @@ def force_channels_last(): #TODO return False +def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False): + if device is None or weight.device == device: + if not copy: + if dtype is None or weight.dtype == dtype: + return weight + return weight.to(dtype=dtype, copy=copy) + + r = torch.empty_like(weight, dtype=dtype, device=device) + r.copy_(weight, non_blocking=non_blocking) + return r + def cast_to_device(tensor, device, dtype, copy=False): - device_supports_cast = False - if tensor.dtype == torch.float32 or tensor.dtype == torch.float16: - device_supports_cast = True - elif tensor.dtype == torch.bfloat16: - if hasattr(device, 'type') and device.type.startswith("cuda"): - device_supports_cast = True - elif is_intel_xpu(): - device_supports_cast = True + non_blocking = device_supports_non_blocking(device) + return cast_to(tensor, dtype=dtype, device=device, non_blocking=non_blocking, copy=copy) - non_blocking = device_should_use_non_blocking(device) - - if device_supports_cast: - if copy: - if tensor.device == device: - return tensor.to(dtype, copy=copy, non_blocking=non_blocking) - return tensor.to(device, copy=copy, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) - else: - return tensor.to(device, non_blocking=non_blocking).to(dtype, non_blocking=non_blocking) - else: - return tensor.to(device, dtype, copy=copy, non_blocking=non_blocking) def xformers_enabled(): global directml_enabled diff --git a/comfy/ops.py b/comfy/ops.py index c90e25ead..a8bfe1ea7 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -20,19 +20,10 @@ import torch import comfy.model_management from comfy.cli_args import args -def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False): - if device is None or weight.device == device: - if not copy: - if dtype is None or weight.dtype == dtype: - return weight - return weight.to(dtype=dtype, copy=copy) - - r = torch.empty_like(weight, dtype=dtype, device=device) - r.copy_(weight, non_blocking=non_blocking) - return r +cast_to = comfy.model_management.cast_to #TODO: remove once no more references def cast_to_input(weight, input, non_blocking=False, copy=True): - return cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) + return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if input is not None: @@ -47,12 +38,12 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = s.bias_function is not None - bias = cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) + bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function) if has_function: bias = s.bias_function(bias) has_function = s.weight_function is not None - weight = cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function) + weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function) if has_function: weight = s.weight_function(weight) return weight, bias diff --git a/server.py b/server.py index c7bf6622d..ada6d90c3 100644 --- a/server.py +++ b/server.py @@ -40,7 +40,7 @@ class BinaryEventTypes: async def send_socket_catch_exception(function, message): try: await function(message) - except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError) as err: + except (aiohttp.ClientError, aiohttp.ClientPayloadError, ConnectionResetError, BrokenPipeError, ConnectionError) as err: logging.warning("send error: {}".format(err)) def get_comfyui_version():