diff --git a/comfy/ldm/flux/model.py b/comfy/ldm/flux/model.py index 8f4d99f54..c4de82795 100644 --- a/comfy/ldm/flux/model.py +++ b/comfy/ldm/flux/model.py @@ -224,19 +224,27 @@ class Flux(nn.Module): if ref_latents is not None: h = 0 w = 0 + index = 0 + index_ref_method = kwargs.get("ref_latents_method", "offset") == "index" for ref in ref_latents: - h_offset = 0 - w_offset = 0 - if ref.shape[-2] + h > ref.shape[-1] + w: - w_offset = w + if index_ref_method: + index += 1 + h_offset = 0 + w_offset = 0 else: - h_offset = h + index = 1 + h_offset = 0 + w_offset = 0 + if ref.shape[-2] + h > ref.shape[-1] + w: + w_offset = w + else: + h_offset = h + h = max(h, ref.shape[-2] + h_offset) + w = max(w, ref.shape[-1] + w_offset) - kontext, kontext_ids = self.process_img(ref, index=1, h_offset=h_offset, w_offset=w_offset) + kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) img = torch.cat([img, kontext], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1) - h = max(h, ref.shape[-2] + h_offset) - w = max(w, ref.shape[-1] + w_offset) txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) diff --git a/comfy/model_base.py b/comfy/model_base.py index cde61df7c..bf874b875 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -890,6 +890,10 @@ class Flux(BaseModel): for lat in ref_latents: latents.append(self.process_latent_in(lat)) out['ref_latents'] = comfy.conds.CONDList(latents) + + ref_latents_method = kwargs.get("reference_latents_method", None) + if ref_latents_method is not None: + out['ref_latents_method'] = comfy.conds.CONDConstant(ref_latents_method) return out def extra_conds_shapes(self, **kwargs): diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index ab3c5363b..cbff2b2d2 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -464,8 +464,6 @@ class OpenAIGPTImage1(ComfyNodeABC): path = "/proxy/openai/images/generations" content_type = "application/json" request_class = OpenAIImageGenerationRequest - img_binaries = [] - mask_binary = None files = [] if image is not None: @@ -484,14 +482,11 @@ class OpenAIGPTImage1(ComfyNodeABC): img_byte_arr = io.BytesIO() img.save(img_byte_arr, format="PNG") img_byte_arr.seek(0) - img_binary = img_byte_arr - img_binary.name = f"image_{i}.png" - img_binaries.append(img_binary) if batch_size == 1: - files.append(("image", img_binary)) + files.append(("image", (f"image_{i}.png", img_byte_arr, "image/png"))) else: - files.append(("image[]", img_binary)) + files.append(("image[]", (f"image_{i}.png", img_byte_arr, "image/png"))) if mask is not None: if image is None: @@ -511,9 +506,7 @@ class OpenAIGPTImage1(ComfyNodeABC): mask_img_byte_arr = io.BytesIO() mask_img.save(mask_img_byte_arr, format="PNG") mask_img_byte_arr.seek(0) - mask_binary = mask_img_byte_arr - mask_binary.name = "mask.png" - files.append(("mask", mask_binary)) + files.append(("mask", ("mask.png", mask_img_byte_arr, "image/png"))) # Build the operation operation = SynchronousOperation( diff --git a/comfy_extras/nodes_flux.py b/comfy_extras/nodes_flux.py index 8a8a17698..c8db75bb3 100644 --- a/comfy_extras/nodes_flux.py +++ b/comfy_extras/nodes_flux.py @@ -100,9 +100,28 @@ class FluxKontextImageScale: return (image, ) +class FluxKontextMultiReferenceLatentMethod: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "conditioning": ("CONDITIONING", ), + "reference_latents_method": (("offset", "index"), ), + }} + + RETURN_TYPES = ("CONDITIONING",) + FUNCTION = "append" + EXPERIMENTAL = True + + CATEGORY = "advanced/conditioning/flux" + + def append(self, conditioning, reference_latents_method): + c = node_helpers.conditioning_set_values(conditioning, {"reference_latents_method": reference_latents_method}) + return (c, ) + NODE_CLASS_MAPPINGS = { "CLIPTextEncodeFlux": CLIPTextEncodeFlux, "FluxGuidance": FluxGuidance, "FluxDisableGuidance": FluxDisableGuidance, "FluxKontextImageScale": FluxKontextImageScale, + "FluxKontextMultiReferenceLatentMethod": FluxKontextMultiReferenceLatentMethod, }