mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 05:10:18 +08:00
use torch.load/save and fix bad behaviors
This commit is contained in:
parent
71ac74ecb8
commit
09afeeb1c4
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pickle
|
|
||||||
import math
|
import math
|
||||||
|
import json
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -383,7 +383,7 @@ class ImageProcessingNode(io.ComfyNode):
|
|||||||
is_group = cls._detect_processing_mode()
|
is_group = cls._detect_processing_mode()
|
||||||
|
|
||||||
# Auto-detect is_output_list if not explicitly set
|
# Auto-detect is_output_list if not explicitly set
|
||||||
# Single processing: True (backend collects results into list)
|
# Single processing: False (backend collects results into list)
|
||||||
# Group processing: True by default (can be False for single-output nodes)
|
# Group processing: True by default (can be False for single-output nodes)
|
||||||
output_is_list = (
|
output_is_list = (
|
||||||
cls.is_output_list if cls.is_output_list is not None else is_group
|
cls.is_output_list if cls.is_output_list is not None else is_group
|
||||||
@ -435,13 +435,7 @@ class ImageProcessingNode(io.ComfyNode):
|
|||||||
# Individual processing: images is single item, call _process
|
# Individual processing: images is single item, call _process
|
||||||
result = cls._process(images, **params)
|
result = cls._process(images, **params)
|
||||||
|
|
||||||
# Wrap result based on is_output_list
|
return io.NodeOutput(result)
|
||||||
if cls.is_output_list:
|
|
||||||
# Result should already be a list (or will be for individual)
|
|
||||||
return io.NodeOutput(result if is_group else [result])
|
|
||||||
else:
|
|
||||||
# Single output - wrap in list for NodeOutput
|
|
||||||
return io.NodeOutput([result])
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _process(cls, image, **kwargs):
|
def _process(cls, image, **kwargs):
|
||||||
@ -1395,7 +1389,7 @@ class SaveTrainingDataset(io.ComfyNode):
|
|||||||
shard_path = os.path.join(output_dir, shard_filename)
|
shard_path = os.path.join(output_dir, shard_filename)
|
||||||
|
|
||||||
with open(shard_path, "wb") as f:
|
with open(shard_path, "wb") as f:
|
||||||
pickle.dump(shard_data, f, protocol=pickle.HIGHEST_PROTOCOL)
|
torch.save(shard_data, f)
|
||||||
|
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
|
f"Saved shard {shard_idx + 1}/{num_shards}: {shard_filename} ({end_idx - start_idx} samples)"
|
||||||
@ -1409,8 +1403,6 @@ class SaveTrainingDataset(io.ComfyNode):
|
|||||||
}
|
}
|
||||||
metadata_path = os.path.join(output_dir, "metadata.json")
|
metadata_path = os.path.join(output_dir, "metadata.json")
|
||||||
with open(metadata_path, "w") as f:
|
with open(metadata_path, "w") as f:
|
||||||
import json
|
|
||||||
|
|
||||||
json.dump(metadata, f, indent=2)
|
json.dump(metadata, f, indent=2)
|
||||||
|
|
||||||
logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
|
logging.info(f"Successfully saved {num_samples} samples to {output_dir}.")
|
||||||
@ -1478,7 +1470,7 @@ class LoadTrainingDataset(io.ComfyNode):
|
|||||||
shard_path = os.path.join(dataset_dir, shard_file)
|
shard_path = os.path.join(dataset_dir, shard_file)
|
||||||
|
|
||||||
with open(shard_path, "rb") as f:
|
with open(shard_path, "rb") as f:
|
||||||
shard_data = pickle.load(f)
|
shard_data = torch.load(f)
|
||||||
|
|
||||||
all_latents.extend(shard_data["latents"])
|
all_latents.extend(shard_data["latents"])
|
||||||
all_conditioning.extend(shard_data["conditioning"])
|
all_conditioning.extend(shard_data["conditioning"])
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user