diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 5e7c4eabb..648b4279d 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -3,6 +3,7 @@ from typing_extensions import override import comfy.model_management from comfy_api.latest import ComfyExtension, io +import torch class Canny(io.ComfyNode): @@ -29,8 +30,8 @@ class Canny(io.ComfyNode): @classmethod def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput: - output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) - img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) + output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold) + img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1) return io.NodeOutput(img_out)