From 09afeeb1c49218b3eb63323b7bd1d6f5ce89dd5c Mon Sep 17 00:00:00 2001
From: Kohaku-Blueleaf <59680068+KohakuBlueleaf@users.noreply.github.com>
Date: Sat, 22 Nov 2025 16:29:55 +0800
Subject: [PATCH] use torch.load/save and fix bad behaviors

---
 comfy_extras/nodes_dataset.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/comfy_extras/nodes_dataset.py b/comfy_extras/nodes_dataset.py
index 83d688088..b23867505 100644
--- a/comfy_extras/nodes_dataset.py
+++ b/comfy_extras/nodes_dataset.py
@@ -1,7 +1,7 @@
 import logging
 import os
-import pickle
 import math
+import json
 
 import numpy as np
 import torch
@@ -383,7 +383,7 @@ class ImageProcessingNode(io.ComfyNode):
         is_group = cls._detect_processing_mode()
         # Auto-detect is_output_list if not explicitly set
-        # Single processing: True (backend collects results into list)
+        # Single processing: False (backend collects results into list)
         # Group processing: True by default (can be False for single-output nodes)
         output_is_list = (
             cls.is_output_list if cls.is_output_list is not None else is_group
         )
@@ -435,13 +435,7 @@ class ImageProcessingNode(io.ComfyNode):
             # Individual processing: images is single item, call _process
             result = cls._process(images, **params)
 
-        # Wrap result based on is_output_list
-        if cls.is_output_list:
-            # Result should already be a list (or will be for individual)
-            return io.NodeOutput(result if is_group else [result])
-        else:
-            # Single output - wrap in list for NodeOutput
-            return io.NodeOutput([result])
+        return io.NodeOutput(result)
 
     @classmethod
     def _process(cls, image, **kwargs):
@@ -1395,7 +1389,7 @@ class SaveTrainingDataset(io.ComfyNode):
             shard_path = os.path.join(output_dir, shard_filename)
             with open(shard_path, "wb") as f:
-                pickle.dump(shard_data, f, protocol=pickle.HIGHEST_PROTOCOL)
+                torch.save(shard_data, f)
 
             logging.info(
                 f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
             )
@@ -1409,8 +1403,6 @@ class SaveTrainingDataset(io.ComfyNode):
         }
         metadata_path = os.path.join(output_dir, "metadata.json")
         with open(metadata_path, "w") as f:
-            import json
-
             json.dump(metadata, f, indent=2)
 
         logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
@@ -1478,7 +1470,7 @@ class LoadTrainingDataset(io.ComfyNode):
             shard_path = os.path.join(dataset_dir, shard_file)
 
             with open(shard_path, "rb") as f:
-                shard_data = pickle.load(f)
+                shard_data = torch.load(f)
 
             all_latents.extend(shard_data["latents"])
             all_conditioning.extend(shard_data["conditioning"])
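
A minimal usage sketch, not part of the patch above, showing how a dataset directory written by SaveTrainingDataset could be read back after this change. The shard filename pattern and the metadata keys are not shown in the diff, so this treats every file other than metadata.json as a shard (an assumption), and relies only on what the diff does show: each shard is a torch.save'd dict with "latents" and "conditioning" lists. The weights_only=False flag matters on recent PyTorch (2.6+), where torch.load defaults to weights-only deserialization and would reject the plain Python objects stored in each shard.

    import json
    import os

    import torch


    def load_saved_dataset(dataset_dir):
        # metadata.json is written next to the shards by SaveTrainingDataset.
        with open(os.path.join(dataset_dir, "metadata.json")) as f:
            metadata = json.load(f)

        all_latents, all_conditioning = [], []
        for name in sorted(os.listdir(dataset_dir)):
            if name == "metadata.json":
                continue  # assumption: everything else in the directory is a shard
            with open(os.path.join(dataset_dir, name), "rb") as f:
                # weights_only=False because shards hold lists/objects, not just tensors.
                shard_data = torch.load(f, weights_only=False)
            all_latents.extend(shard_data["latents"])
            all_conditioning.extend(shard_data["conditioning"])
        return metadata, all_latents, all_conditioning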