Merge branch 'master' of github.com:comfyanonymous/ComfyUI

This commit is contained in:
doctorpangloss 2025-09-26 13:18:23 -07:00
commit 8ab28996fb
6 changed files with 365 additions and 39 deletions

View File

@ -1404,7 +1404,7 @@ class WanT2VCrossAttentionGather(WanSelfAttention):
x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options) x = optimized_attention(q, k, v, heads=self.num_heads, skip_reshape=True, skip_output_reshape=True, transformer_options=transformer_options)
x = x.transpose(1, 2).view(b, -1, n, d).flatten(2) x = x.transpose(1, 2).reshape(b, -1, n * d)
x = self.o(x) x = self.o(x)
return x return x

View File

@ -737,7 +737,9 @@ def _load_models_gpu(models: Sequence[ModelManageable], memory_required: int = 0
if loaded_model.model.is_clone(current_loaded_models[i].model): if loaded_model.model.is_clone(current_loaded_models[i].model):
to_unload = [i] + to_unload to_unload = [i] + to_unload
for i in to_unload: for i in to_unload:
current_loaded_models.pop(i).model.detach(unpatch_all=False) model_to_unload = current_loaded_models.pop(i)
model_to_unload.model.detach(unpatch_all=False)
model_to_unload.model_finalizer.detach()
total_memory_required = {} total_memory_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:

View File

@ -9,8 +9,9 @@ class Rodin3DGenerateRequest(BaseModel):
seed: int = Field(..., description="seed_") seed: int = Field(..., description="seed_")
tier: str = Field(..., description="Tier of generation.") tier: str = Field(..., description="Tier of generation.")
material: str = Field(..., description="The material type.") material: str = Field(..., description="The material type.")
quality: str = Field(..., description="The generation quality of the mesh.") quality_override: int = Field(..., description="The poly count of the mesh.")
mesh_mode: str = Field(..., description="It controls the type of faces of generated models.") mesh_mode: str = Field(..., description="It controls the type of faces of generated models.")
TAPose: Optional[bool] = Field(None, description="")
class GenerateJobsData(BaseModel): class GenerateJobsData(BaseModel):
uuids: List[str] = Field(..., description="str LIST") uuids: List[str] = Field(..., description="str LIST")

View File

@ -121,10 +121,10 @@ class Rodin3DAPI:
else: else:
return "Generating" return "Generating"
async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): async def create_generate_task(self, images=None, seed=1, material="PBR", quality_override=18000, tier="Regular", mesh_mode="Quad", TAPose = False, **kwargs):
if images is None: if images is None:
raise Exception("Rodin 3D generate requires at least 1 image.") raise Exception("Rodin 3D generate requires at least 1 image.")
if len(images) >= 5: if len(images) > 5:
raise Exception("Rodin 3D generate requires up to 5 image.") raise Exception("Rodin 3D generate requires up to 5 image.")
path = "/proxy/rodin/api/v2/rodin" path = "/proxy/rodin/api/v2/rodin"
@ -139,8 +139,9 @@ class Rodin3DAPI:
seed=seed, seed=seed,
tier=tier, tier=tier,
material=material, material=material,
quality=quality, quality_override=quality_override,
mesh_mode=mesh_mode mesh_mode=mesh_mode,
TAPose=TAPose,
), ),
files=[ files=[
( (
@ -211,23 +212,36 @@ class Rodin3DAPI:
return await operation.execute() return await operation.execute()
def get_quality_mode(self, poly_count): def get_quality_mode(self, poly_count):
if poly_count == "200K-Triangle": polycount = poly_count.split("-")
poly = polycount[1]
count = polycount[0]
if poly == "Triangle":
mesh_mode = "Raw" mesh_mode = "Raw"
quality = "medium" elif poly == "Quad":
mesh_mode = "Quad"
else: else:
mesh_mode = "Quad" mesh_mode = "Quad"
if poly_count == "4K-Quad":
quality = "extra-low"
elif poly_count == "8K-Quad":
quality = "low"
elif poly_count == "18K-Quad":
quality = "medium"
elif poly_count == "50K-Quad":
quality = "high"
else:
quality = "medium"
return mesh_mode, quality if count == "4K":
quality_override = 4000
elif count == "8K":
quality_override = 8000
elif count == "18K":
quality_override = 18000
elif count == "50K":
quality_override = 50000
elif count == "2K":
quality_override = 2000
elif count == "20K":
quality_override = 20000
elif count == "150K":
quality_override = 150000
elif count == "500K":
quality_override = 500000
else:
quality_override = 18000
return mesh_mode, quality_override
async def download_files(self, url_list): async def download_files(self, url_list):
save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
@ -300,9 +314,9 @@ class Rodin3D_Regular(Rodin3DAPI):
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality = self.get_quality_mode(Polygon_count) mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
**kwargs) **kwargs)
await self.poll_for_task_status(subscription_key, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
@ -346,9 +360,9 @@ class Rodin3D_Detail(Rodin3DAPI):
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality = self.get_quality_mode(Polygon_count) mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
**kwargs) **kwargs)
await self.poll_for_task_status(subscription_key, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
@ -392,9 +406,9 @@ class Rodin3D_Smooth(Rodin3DAPI):
m_images = [] m_images = []
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
mesh_mode, quality = self.get_quality_mode(Polygon_count) mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality=quality, tier=tier, mesh_mode=mesh_mode, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode,
**kwargs) **kwargs)
await self.poll_for_task_status(subscription_key, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
@ -446,10 +460,10 @@ class Rodin3D_Sketch(Rodin3DAPI):
for i in range(num_images): for i in range(num_images):
m_images.append(Images[i]) m_images.append(Images[i])
material_type = "PBR" material_type = "PBR"
quality = "medium" quality_override = 18000
mesh_mode = "Quad" mesh_mode = "Quad"
task_uuid, subscription_key = await self.create_generate_task( task_uuid, subscription_key = await self.create_generate_task(
images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs images=m_images, seed=Seed, material=material_type, quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, **kwargs
) )
await self.poll_for_task_status(subscription_key, **kwargs) await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs) download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
@ -457,6 +471,80 @@ class Rodin3D_Sketch(Rodin3DAPI):
return (model,) return (model,)
class Rodin3D_Gen2(Rodin3DAPI):
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images":
(
IO.IMAGE,
{
"forceInput":True,
}
)
},
"optional": {
"Seed": (
IO.INT,
{
"default":0,
"min":0,
"max":65535,
"display":"number"
}
),
"Material_Type": (
IO.COMBO,
{
"options": ["PBR", "Shaded"],
"default": "PBR"
}
),
"Polygon_count": (
IO.COMBO,
{
"options": ["4K-Quad", "8K-Quad", "18K-Quad", "50K-Quad", "2K-Triangle", "20K-Triangle", "150K-Triangle", "500K-Triangle"],
"default": "500K-Triangle"
}
),
"TAPose": (
IO.BOOLEAN,
{
"default": False,
}
)
},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
},
}
async def api_call(
self,
Images,
Seed,
Material_Type,
Polygon_count,
TAPose,
**kwargs
):
tier = "Gen-2"
num_images = Images.shape[0]
m_images = []
for i in range(num_images):
m_images.append(Images[i])
mesh_mode, quality_override = self.get_quality_mode(Polygon_count)
task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type,
quality_override=quality_override, tier=tier, mesh_mode=mesh_mode, TAPose=TAPose,
**kwargs)
await self.poll_for_task_status(subscription_key, **kwargs)
download_list = await self.get_rodin_download_list(task_uuid, **kwargs)
model = await self.download_files(download_list)
return (model,)
# A dictionary that contains all nodes you want to export with their names # A dictionary that contains all nodes you want to export with their names
# NOTE: names should be globally unique # NOTE: names should be globally unique
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
@ -464,6 +552,7 @@ NODE_CLASS_MAPPINGS = {
"Rodin3D_Detail": Rodin3D_Detail, "Rodin3D_Detail": Rodin3D_Detail,
"Rodin3D_Smooth": Rodin3D_Smooth, "Rodin3D_Smooth": Rodin3D_Smooth,
"Rodin3D_Sketch": Rodin3D_Sketch, "Rodin3D_Sketch": Rodin3D_Sketch,
"Rodin3D_Gen2": Rodin3D_Gen2,
} }
# A dictionary that contains the friendly/humanly readable titles for the nodes # A dictionary that contains the friendly/humanly readable titles for the nodes
@ -472,4 +561,5 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"Rodin3D_Detail": "Rodin 3D Generate - Detail Generate", "Rodin3D_Detail": "Rodin 3D Generate - Detail Generate",
"Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate", "Rodin3D_Smooth": "Rodin 3D Generate - Smooth Generate",
"Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate", "Rodin3D_Sketch": "Rodin 3D Generate - Sketch Generate",
"Rodin3D_Gen2": "Rodin 3D Generate - Gen-2 Generate",
} }

View File

@ -11,6 +11,7 @@ import torch
import comfy.model_management import comfy.model_management
from comfy import node_helpers from comfy import node_helpers
import logging
from comfy.cli_args import args from comfy.cli_args import args
from comfy.cmd import folder_paths from comfy.cmd import folder_paths
from comfy.comfy_types import FileLocator from comfy.comfy_types import FileLocator
@ -379,6 +380,7 @@ class LoadAudio:
return "Invalid audio file: {}".format(audio) return "Invalid audio file: {}".format(audio)
return True return True
class RecordAudio: class RecordAudio:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
@ -386,7 +388,7 @@ class RecordAudio:
CATEGORY = "audio" CATEGORY = "audio"
RETURN_TYPES = ("AUDIO", ) RETURN_TYPES = ("AUDIO",)
FUNCTION = "load" FUNCTION = "load"
def load(self, audio): def load(self, audio):
@ -398,7 +400,223 @@ class RecordAudio:
waveform, sample_rate = torchaudio.load(audio_path) waveform, sample_rate = torchaudio.load(audio_path)
audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate} audio = {"waveform": waveform.unsqueeze(0), "sample_rate": sample_rate}
return (audio, ) return (audio,)
class TrimAudioDuration:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio": ("AUDIO",),
"start_index": ("FLOAT", {"default": 0.0, "min": -0xffffffffffffffff, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Start time in seconds, can be negative to count from the end (supports sub-seconds)."}),
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "step": 0.01, "tooltip": "Duration in seconds"}),
},
}
FUNCTION = "trim"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Trim audio tensor into chosen time range."
def trim(self, audio, start_index, duration):
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
audio_length = waveform.shape[-1]
if start_index < 0:
start_frame = audio_length + int(round(start_index * sample_rate))
else:
start_frame = int(round(start_index * sample_rate))
start_frame = max(0, min(start_frame, audio_length - 1))
end_frame = start_frame + int(round(duration * sample_rate))
end_frame = max(0, min(end_frame, audio_length))
if start_frame >= end_frame:
raise ValueError("AudioTrim: Start time must be less than end time and be within the audio length.")
return ({"waveform": waveform[..., start_frame:end_frame], "sample_rate": sample_rate},)
class SplitAudioChannels:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
}}
RETURN_TYPES = ("AUDIO", "AUDIO")
RETURN_NAMES = ("left", "right")
FUNCTION = "separate"
CATEGORY = "audio"
DESCRIPTION = "Separates the audio into left and right channels."
def separate(self, audio):
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
if waveform.shape[1] != 2:
raise ValueError("AudioSplit: Input audio has only one channel.")
left_channel = waveform[..., 0:1, :]
right_channel = waveform[..., 1:2, :]
return ({"waveform": left_channel, "sample_rate": sample_rate}, {"waveform": right_channel, "sample_rate": sample_rate})
def match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2):
try:
import torchaudio # pylint: disable=import-error
except ImportError as exc_info:
raise TorchAudioNotFoundError()
if sample_rate_1 != sample_rate_2:
if sample_rate_1 > sample_rate_2:
waveform_2 = torchaudio.functional.resample(waveform_2, sample_rate_2, sample_rate_1)
output_sample_rate = sample_rate_1
logging.info(f"Resampling audio2 from {sample_rate_2}Hz to {sample_rate_1}Hz for merging.")
else:
waveform_1 = torchaudio.functional.resample(waveform_1, sample_rate_1, sample_rate_2)
output_sample_rate = sample_rate_2
logging.info(f"Resampling audio1 from {sample_rate_1}Hz to {sample_rate_2}Hz for merging.")
else:
output_sample_rate = sample_rate_1
return waveform_1, waveform_2, output_sample_rate
class AudioConcat:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"direction": (['after', 'before'], {"default": 'after', "tooltip": "Whether to append audio2 after or before audio1."}),
}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "concat"
CATEGORY = "audio"
DESCRIPTION = "Concatenates the audio1 to audio2 in the specified direction."
def concat(self, audio1, audio2, direction):
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
sample_rate_2 = audio2["sample_rate"]
if waveform_1.shape[1] == 1:
waveform_1 = waveform_1.repeat(1, 2, 1)
logging.info("AudioConcat: Converted mono audio1 to stereo by duplicating the channel.")
if waveform_2.shape[1] == 1:
waveform_2 = waveform_2.repeat(1, 2, 1)
logging.info("AudioConcat: Converted mono audio2 to stereo by duplicating the channel.")
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
if direction == 'after':
concatenated_audio = torch.cat((waveform_1, waveform_2), dim=2)
elif direction == 'before':
concatenated_audio = torch.cat((waveform_2, waveform_1), dim=2)
return ({"waveform": concatenated_audio, "sample_rate": output_sample_rate},)
class AudioMerge:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"audio1": ("AUDIO",),
"audio2": ("AUDIO",),
"merge_method": (["add", "mean", "subtract", "multiply"], {"tooltip": "The method used to combine the audio waveforms."}),
},
}
FUNCTION = "merge"
RETURN_TYPES = ("AUDIO",)
CATEGORY = "audio"
DESCRIPTION = "Combine two audio tracks by overlaying their waveforms."
def merge(self, audio1, audio2, merge_method):
waveform_1 = audio1["waveform"]
waveform_2 = audio2["waveform"]
sample_rate_1 = audio1["sample_rate"]
sample_rate_2 = audio2["sample_rate"]
waveform_1, waveform_2, output_sample_rate = match_audio_sample_rates(waveform_1, sample_rate_1, waveform_2, sample_rate_2)
length_1 = waveform_1.shape[-1]
length_2 = waveform_2.shape[-1]
if length_2 > length_1:
logging.info(f"AudioMerge: Trimming audio2 from {length_2} to {length_1} samples to match audio1 length.")
waveform_2 = waveform_2[..., :length_1]
elif length_2 < length_1:
logging.info(f"AudioMerge: Padding audio2 from {length_2} to {length_1} samples to match audio1 length.")
pad_shape = list(waveform_2.shape)
pad_shape[-1] = length_1 - length_2
pad_tensor = torch.zeros(pad_shape, dtype=waveform_2.dtype, device=waveform_2.device)
waveform_2 = torch.cat((waveform_2, pad_tensor), dim=-1)
waveform = None
if merge_method == "add":
waveform = waveform_1 + waveform_2
elif merge_method == "subtract":
waveform = waveform_1 - waveform_2
elif merge_method == "multiply":
waveform = waveform_1 * waveform_2
elif merge_method == "mean":
waveform = (waveform_1 + waveform_2) / 2
max_val = waveform.abs().max()
if max_val > 1.0:
waveform = waveform / max_val
return ({"waveform": waveform, "sample_rate": output_sample_rate},)
class AudioAdjustVolume:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"audio": ("AUDIO",),
"volume": ("INT", {"default": 1.0, "min": -100, "max": 100, "tooltip": "Volume adjustment in decibels (dB). 0 = no change, +6 = double, -6 = half, etc"}),
}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "adjust_volume"
CATEGORY = "audio"
def adjust_volume(self, audio, volume):
if volume == 0:
return (audio,)
waveform = audio["waveform"]
sample_rate = audio["sample_rate"]
gain = 10 ** (volume / 20)
waveform = waveform * gain
return ({"waveform": waveform, "sample_rate": sample_rate},)
class EmptyAudio:
@classmethod
def INPUT_TYPES(s):
return {"required": {
"duration": ("FLOAT", {"default": 60.0, "min": 0.0, "max": 0xffffffffffffffff, "step": 0.01, "tooltip": "Duration of the empty audio clip in seconds"}),
"sample_rate": ("INT", {"default": 44100, "tooltip": "Sample rate of the empty audio clip."}),
"channels": ("INT", {"default": 2, "min": 1, "max": 2, "tooltip": "Number of audio channels (1 for mono, 2 for stereo)."}),
}}
RETURN_TYPES = ("AUDIO",)
FUNCTION = "create_empty_audio"
CATEGORY = "audio"
def create_empty_audio(self, duration, sample_rate, channels):
num_samples = int(round(duration * sample_rate))
waveform = torch.zeros((1, channels, num_samples), dtype=torch.float32)
return ({"waveform": waveform, "sample_rate": sample_rate},)
NODE_CLASS_MAPPINGS = { NODE_CLASS_MAPPINGS = {
@ -412,6 +630,12 @@ NODE_CLASS_MAPPINGS = {
"PreviewAudio": PreviewAudio, "PreviewAudio": PreviewAudio,
"ConditioningStableAudio": ConditioningStableAudio, "ConditioningStableAudio": ConditioningStableAudio,
"RecordAudio": RecordAudio, "RecordAudio": RecordAudio,
"TrimAudioDuration": TrimAudioDuration,
"SplitAudioChannels": SplitAudioChannels,
"AudioConcat": AudioConcat,
"AudioMerge": AudioMerge,
"AudioAdjustVolume": AudioAdjustVolume,
"EmptyAudio": EmptyAudio,
} }
NODE_DISPLAY_NAME_MAPPINGS = { NODE_DISPLAY_NAME_MAPPINGS = {
@ -424,4 +648,10 @@ NODE_DISPLAY_NAME_MAPPINGS = {
"SaveAudioMP3": "Save Audio (MP3)", "SaveAudioMP3": "Save Audio (MP3)",
"SaveAudioOpus": "Save Audio (Opus)", "SaveAudioOpus": "Save Audio (Opus)",
"RecordAudio": "Record Audio", "RecordAudio": "Record Audio",
"TrimAudioDuration": "Trim Audio Duration",
"SplitAudioChannels": "Split Audio Channels",
"AudioConcat": "Audio Concat",
"AudioMerge": "Audio Merge",
"AudioAdjustVolume": "Audio Adjust Volume",
"EmptyAudio": "Empty Audio",
} }

View File

@ -14,35 +14,38 @@ from comfy.nodes.common import MAX_RESOLUTION
def composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False): def composite(destination, source, x, y, mask=None, multiplier=8, resize_source=False):
source = source.to(destination.device) source = source.to(destination.device)
if resize_source: if resize_source:
source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") source = torch.nn.functional.interpolate(source, size=(destination.shape[-2], destination.shape[-1]), mode="bilinear")
source = utils.repeat_to_batch_size(source, destination.shape[0]) source = utils.repeat_to_batch_size(source, destination.shape[0])
x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) x = max(-source.shape[-1] * multiplier, min(x, destination.shape[-1] * multiplier))
y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) y = max(-source.shape[-2] * multiplier, min(y, destination.shape[-2] * multiplier))
left, top = (x // multiplier, y // multiplier) left, top = (x // multiplier, y // multiplier)
right, bottom = (left + source.shape[3], top + source.shape[2],) right, bottom = (left + source.shape[-1], top + source.shape[-2],)
if mask is None: if mask is None:
mask = torch.ones_like(source) mask = torch.ones_like(source)
else: else:
mask = mask.to(destination.device, copy=True) mask = mask.to(destination.device, copy=True)
mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[-2], source.shape[-1]), mode="bilinear")
mask = utils.repeat_to_batch_size(mask, source.shape[0]) mask = utils.repeat_to_batch_size(mask, source.shape[0])
# calculate the bounds of the source that will be overlapping the destination # calculate the bounds of the source that will be overlapping the destination
# this prevents the source trying to overwrite latent pixels that are out of bounds # this prevents the source trying to overwrite latent pixels that are out of bounds
# of the destination # of the destination
visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) visible_width, visible_height = (destination.shape[-1] - left + min(0, x), destination.shape[-2] - top + min(0, y),)
mask = mask[:, :, :visible_height, :visible_width] mask = mask[:, :, :visible_height, :visible_width]
if mask.ndim < source.ndim:
mask = mask.unsqueeze(1)
inverse_mask = torch.ones_like(mask) - mask inverse_mask = torch.ones_like(mask) - mask
source_portion = mask * source[:, :, :visible_height, :visible_width] source_portion = mask * source[..., :visible_height, :visible_width]
destination_portion = inverse_mask * destination[:, :, top:bottom, left:right] destination_portion = inverse_mask * destination[..., top:bottom, left:right]
destination[:, :, top:bottom, left:right] = source_portion + destination_portion destination[..., top:bottom, left:right] = source_portion + destination_portion
return destination return destination