mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-12 12:47:45 +08:00
Remove Chroma Radiance image conversion and stub VAE nodes
Add a chroma_radiance option to the VAELoader builtin node which uses comfy.sd.PixelspaceConversionVAE. Add a PixelspaceConversionVAE to comfy.sd for converting BHWC 0..1 <-> BCHW -1..1.
This commit is contained in:
parent
0828916ef1
commit
1bc45d3c0a
60
comfy/sd.py
60
comfy/sd.py
@ -785,6 +785,66 @@ class VAE:
|
|||||||
except:
|
except:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# "Fake" VAE that converts from IMAGE B, H, W, C and values on the scale of 0..1
# to LATENT B, C, H, W and values on the scale of -1..1.
class PixelspaceConversionVAE:
    """Stub VAE converting between IMAGE and LATENT tensor conventions.

    "Encoding" maps a BHWC image with values in 0..1 to a BCHW "latent"
    with values in -1..1; "decoding" is the exact inverse. No learned
    model is involved — this exists so pixel-space models (e.g. Chroma
    Radiance) can be used with nodes that expect a VAE object.
    """

    def __init__(self, size_increment: int = 16):
        # Device the (cheap) tensor conversions run on.
        self.intermediate_device = comfy.model_management.intermediate_device()
        # Spatial dims are center-cropped to multiples of this on encode.
        self.size_increment = size_increment

    def vae_encode_crop_pixels(self, pixels: torch.Tensor) -> torch.Tensor:
        """Center-crop the spatial dims of ``pixels`` (B, ..., C layout) to
        multiples of ``self.size_increment``. Returns a view when possible.
        """
        if self.size_increment == 1:
            return pixels
        dims = pixels.shape[1:-1]
        for d in range(len(dims)):
            d_adj = (dims[d] // self.size_increment) * self.size_increment
            # Bug fix: the original compared d_adj against the loop index d
            # instead of the dimension size. Harmless in practice (the
            # narrow below degenerates to a full-size view), but this is
            # the intended "already aligned" early-out.
            if d_adj == dims[d]:
                continue
            d_offset = (dims[d] % self.size_increment) // 2
            pixels = pixels.narrow(d + 1, d_offset, d_adj)
        return pixels

    def encode(self, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
        """Convert a BHWC (or unbatched HWC) image in 0..1 to a BCHW latent
        in -1..1.

        Raises ValueError for unexpected ranks, spatial dims smaller than
        ``size_increment`` after cropping, or channel counts other than
        1 (grayscale, broadcast to RGB) or 3.
        """
        if pixels.ndim == 3:
            # Single image without a batch dimension.
            pixels = pixels.unsqueeze(0)
        elif pixels.ndim != 4:
            raise ValueError("Unexpected input image shape")
        # Ensure the image has spatial dimensions that are multiples of 16.
        pixels = self.vae_encode_crop_pixels(pixels)
        h, w, c = pixels.shape[1:]
        if h < self.size_increment or w < self.size_increment:
            raise ValueError(f"Image inputs must have height/width of at least {self.size_increment} pixel(s).")
        # NOTE(review): c was captured before this slice, so inputs with
        # more than 3 channels still raise below — the slice only matters
        # if that check is ever relaxed (e.g. to drop an alpha channel).
        pixels = pixels[..., :3]
        if c == 1:
            # Broadcast grayscale to RGB.
            pixels = pixels.expand(-1, -1, -1, 3)
        elif c != 3:
            raise ValueError("Unexpected number of channels in input image")
        # Rescale to -1..1 and move the channel dimension to position 1.
        latent = pixels.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
        latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
        latent -= 0.5
        latent *= 2
        return latent.clamp_(-1, 1)

    def decode(self, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
        """Convert a BCHW latent in -1..1 back to a BHWC image in 0..1."""
        # Rescale to 0..1 and move the channel dimension to the end.
        img = samples.to(device=self.intermediate_device, dtype=torch.float32, copy=True)
        img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
        img += 1.0
        img *= 0.5
        return img.clamp_(0, 1)

    # There is no real model, so "tiled" operation is identical to plain.
    encode_tiled = encode
    decode_tiled = decode

    @classmethod
    def spacial_compression_decode(cls) -> int:
        # This just exists so the tiled VAE nodes don't crash.
        return 1

    spacial_compression_encode = spacial_compression_decode
    temporal_compression_decode = spacial_compression_decode
||||||
class StyleModel:
|
class StyleModel:
|
||||||
def __init__(self, model, device="cpu"):
|
def __init__(self, model, device="cpu"):
|
||||||
self.model = model
|
self.model = model
|
||||||
|
|||||||
@ -28,110 +28,6 @@ class EmptyChromaRadianceLatentImage(io.ComfyNode):
|
|||||||
return io.NodeOutput({"samples":latent})
|
return io.NodeOutput({"samples":latent})
|
||||||
|
|
||||||
|
|
||||||
class ChromaRadianceStubVAE:
    """Stub "VAE" for Chroma Radiance.

    Converts IMAGE tensors (B, H, W, C with values 0..1) to/from LATENT
    tensors (B, C, H, W with values -1..1). Purely a rescale/permute —
    no learned model is involved.
    """

    @staticmethod
    def vae_encode_crop_pixels(pixels: torch.Tensor) -> torch.Tensor:
        """Center-crop the spatial dims of ``pixels`` (B, ..., C layout)
        to multiples of 16. Returns a view when possible."""
        dims = pixels.shape[1:-1]
        for d in range(len(dims)):
            d_adj = (dims[d] // 16) * 16
            # Bug fix: the original compared d_adj against the loop index d
            # rather than the dimension size; the intended early-out is
            # "dimension already a multiple of 16".
            if d_adj == dims[d]:
                continue
            d_offset = (dims[d] % 16) // 2
            pixels = pixels.narrow(d + 1, d_offset, d_adj)
        return pixels

    @classmethod
    def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
        """Convert a BHWC (or unbatched HWC) image in 0..1 to a BCHW latent
        in -1..1. Raises ValueError on bad rank, too-small images, or
        channel counts other than 1 (broadcast to RGB) or 3."""
        device = comfy.model_management.intermediate_device()
        if pixels.ndim == 3:
            # Single image without a batch dimension.
            pixels = pixels.unsqueeze(0)
        elif pixels.ndim != 4:
            raise ValueError("Unexpected input image shape")
        # Ensure the image has spatial dimensions that are multiples of 16.
        pixels = cls.vae_encode_crop_pixels(pixels)
        h, w, c = pixels.shape[1:]
        if h < 16 or w < 16:
            raise ValueError("Chroma Radiance image inputs must have height/width of at least 16 pixels.")
        # NOTE(review): c was captured before this slice, so >3-channel
        # inputs still raise below; the slice only matters if that check
        # is ever relaxed.
        pixels = pixels[..., :3]
        if c == 1:
            pixels = pixels.expand(-1, -1, -1, 3)
        elif c != 3:
            raise ValueError("Unexpected number of channels in input image")
        # Rescale to -1..1 and move the channel dimension to position 1.
        latent = pixels.to(device=device, dtype=torch.float32, copy=True)
        latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
        latent -= 0.5
        latent *= 2
        return latent.clamp_(-1, 1)

    @classmethod
    def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
        """Convert a BCHW latent in -1..1 back to a BHWC image in 0..1."""
        device = comfy.model_management.intermediate_device()
        # Rescale to 0..1 and move the channel dimension to the end.
        img = samples.to(device=device, dtype=torch.float32, copy=True)
        img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
        img += 1.0
        img *= 0.5
        return img.clamp_(0, 1)

    # No real model, so "tiled" operation is identical to plain.
    encode_tiled = encode
    decode_tiled = decode

    @classmethod
    def spacial_compression_decode(cls) -> int:
        # This just exists so the tiled VAE nodes don't crash.
        return 1

    spacial_compression_encode = spacial_compression_decode
    temporal_compression_decode = spacial_compression_decode
|
|
||||||
class ChromaRadianceLatentToImage(io.ComfyNode):
    """Node that turns a Chroma Radiance LATENT back into an IMAGE."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        schema = io.Schema(
            node_id="ChromaRadianceLatentToImage",
            category="latent/chroma_radiance",
            description="For use with Chroma Radiance. Converts an input LATENT to IMAGE.",
            inputs=[io.Latent.Input(id="latent")],
            outputs=[io.Image.Output()],
        )
        return schema

    @classmethod
    def execute(cls, *, latent: dict) -> io.NodeOutput:
        # Radiance latents are just -1..1 pixels; decoding is a pure
        # rescale/permute with no model involved.
        samples = latent["samples"]
        image = ChromaRadianceStubVAE.decode(samples)
        return io.NodeOutput(image)
|
||||||
class ChromaRadianceImageToLatent(io.ComfyNode):
    """Node that turns an IMAGE into a Chroma Radiance LATENT."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        schema = io.Schema(
            node_id="ChromaRadianceImageToLatent",
            category="latent/chroma_radiance",
            description="For use with Chroma Radiance. Converts an input IMAGE to LATENT. Note: Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
            inputs=[io.Image.Input(id="image")],
            outputs=[io.Latent.Output()],
        )
        return schema

    @classmethod
    def execute(cls, *, image: torch.Tensor) -> io.NodeOutput:
        # Encoding is a pure rescale/permute (with a center-crop to
        # multiples of 16) — no model involved.
        samples = ChromaRadianceStubVAE.encode(image)
        return io.NodeOutput({"samples": samples})
|
||||||
class ChromaRadianceStubVAENode(io.ComfyNode):
    """Node that outputs a stub VAE object for Chroma Radiance workflows."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        schema = io.Schema(
            node_id="ChromaRadianceStubVAE",
            category="vae/chroma_radiance",
            description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Chroma Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
            outputs=[io.Vae.Output()],
        )
        return schema

    @classmethod
    def execute(cls) -> io.NodeOutput:
        # The stub VAE is stateless, so a fresh instance per call is fine.
        vae = ChromaRadianceStubVAE()
        return io.NodeOutput(vae)
|
||||||
class ChromaRadianceOptions(io.ComfyNode):
|
class ChromaRadianceOptions(io.ComfyNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def define_schema(cls) -> io.Schema:
|
def define_schema(cls) -> io.Schema:
|
||||||
@ -210,9 +106,6 @@ class ChromaRadianceExtension(ComfyExtension):
|
|||||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||||
return [
|
return [
|
||||||
EmptyChromaRadianceLatentImage,
|
EmptyChromaRadianceLatentImage,
|
||||||
ChromaRadianceLatentToImage,
|
|
||||||
ChromaRadianceImageToLatent,
|
|
||||||
ChromaRadianceStubVAENode,
|
|
||||||
ChromaRadianceOptions,
|
ChromaRadianceOptions,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
5
nodes.py
5
nodes.py
@ -730,6 +730,7 @@ class VAELoader:
|
|||||||
vaes.append("taesd3")
|
vaes.append("taesd3")
|
||||||
if f1_taesd_dec and f1_taesd_enc:
|
if f1_taesd_dec and f1_taesd_enc:
|
||||||
vaes.append("taef1")
|
vaes.append("taef1")
|
||||||
|
vaes.append("chroma_radiance")
|
||||||
return vaes
|
return vaes
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -772,7 +773,9 @@ class VAELoader:
|
|||||||
|
|
||||||
#TODO: scale factor?
|
#TODO: scale factor?
|
||||||
def load_vae(self, vae_name):
|
def load_vae(self, vae_name):
|
||||||
if vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
if vae_name == "chroma_radiance":
|
||||||
|
return (comfy.sd.PixelspaceConversionVAE(),)
|
||||||
|
elif vae_name in ["taesd", "taesdxl", "taesd3", "taef1"]:
|
||||||
sd = self.load_taesd(vae_name)
|
sd = self.load_taesd(vae_name)
|
||||||
else:
|
else:
|
||||||
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
vae_path = folder_paths.get_full_path_or_raise("vae", vae_name)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user