use torch.load/save and fix bad behaviors

Kohaku-Blueleaf 2025-11-22 16:29:55 +08:00
parent 71ac74ecb8
commit 09afeeb1c4


@@ -1,7 +1,7 @@
 import logging
 import os
-import pickle
 import math
+import json
 import numpy as np
 import torch
@@ -383,7 +383,7 @@ class ImageProcessingNode(io.ComfyNode):
         is_group = cls._detect_processing_mode()
         # Auto-detect is_output_list if not explicitly set
-        # Single processing: True (backend collects results into list)
+        # Single processing: False (backend collects results into list)
         # Group processing: True by default (can be False for single-output nodes)
         output_is_list = (
             cls.is_output_list if cls.is_output_list is not None else is_group
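The corrected comment now agrees with the resolution expression in this hunk: an explicit is_output_list always wins, and only None falls back to the detected processing mode. A minimal standalone sketch of that tri-state logic (the helper name is hypothetical, not from this file):

def resolve_output_is_list(is_output_list, is_group):
    # None means "auto": group processing emits a list, single processing
    # returns a bare item that the backend collects.
    return is_output_list if is_output_list is not None else is_group

assert resolve_output_is_list(None, is_group=True) is True
assert resolve_output_is_list(None, is_group=False) is False
assert resolve_output_is_list(False, is_group=True) is False  # explicit override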
@@ -435,13 +435,7 @@ class ImageProcessingNode(io.ComfyNode):
             # Individual processing: images is single item, call _process
             result = cls._process(images, **params)
-        # Wrap result based on is_output_list
-        if cls.is_output_list:
-            # Result should already be a list (or will be for individual)
-            return io.NodeOutput(result if is_group else [result])
-        else:
-            # Single output - wrap in list for NodeOutput
-            return io.NodeOutput([result])
+        return io.NodeOutput(result)
 
     @classmethod
     def _process(cls, image, **kwargs):
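The removed branch tested `cls.is_output_list` directly, which silently took the else path whenever the attribute was the tri-state None; presumably one of the "bad behaviors" the commit title refers to. The simplified return leaves shaping to the producer: group processing hands back an already-built list, individual processing returns a bare item. A hedged sketch of what a subclass returns under this contract (the subclass and its transform are hypothetical):

import torch

class InvertNode(ImageProcessingNode):  # hypothetical example subclass
    @classmethod
    def _process(cls, image: torch.Tensor, **kwargs) -> torch.Tensor:
        # Return the bare result; execute() now passes it straight to
        # io.NodeOutput(result) with no [result] wrapping.
        return 1.0 - image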
@@ -1395,7 +1389,7 @@ class SaveTrainingDataset(io.ComfyNode):
             shard_path = os.path.join(output_dir, shard_filename)
             with open(shard_path, "wb") as f:
-                pickle.dump(shard_data, f, protocol=pickle.HIGHEST_PROTOCOL)
+                torch.save(shard_data, f)
             logging.info(
                 f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
@@ -1409,8 +1403,6 @@ class SaveTrainingDataset(io.ComfyNode):
         }
         metadata_path = os.path.join(output_dir, "metadata.json")
         with open(metadata_path, "w") as f:
-            import json
-
             json.dump(metadata, f, indent=2)
         logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
@@ -1478,7 +1470,7 @@ class LoadTrainingDataset(io.ComfyNode):
             shard_path = os.path.join(dataset_dir, shard_file)
             with open(shard_path, "rb") as f:
-                shard_data = pickle.load(f)
+                shard_data = torch.load(f)
             all_latents.extend(shard_data["latents"])
             all_conditioning.extend(shard_data["conditioning"])
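One caveat worth noting: since PyTorch 2.6, torch.load defaults to weights_only=True, which only admits tensors, primitives, and plain containers. That covers shards holding lists of tensors, but if "conditioning" ever carries other Python objects the load will raise and needs an explicit opt-out, acceptable only for trusted, self-produced files. A hedged sketch:

import torch

shard_path = "shard_000.pt"  # illustrative filename

# Opt out of weights_only only for files this node itself produced;
# weights_only=False runs full unpickling and must not see untrusted data.
with open(shard_path, "rb") as f:
    shard_data = torch.load(f, map_location="cpu", weights_only=False)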