mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 04:40:15 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
45d8f6570b
@ -853,19 +853,18 @@ def reshape_mask(input_mask, output_shape):
|
|||||||
dims = len(output_shape) - 2
|
dims = len(output_shape) - 2
|
||||||
|
|
||||||
if dims == 1:
|
if dims == 1:
|
||||||
mask = input_mask
|
|
||||||
scale_mode = "linear"
|
scale_mode = "linear"
|
||||||
|
|
||||||
if dims == 2:
|
if dims == 2:
|
||||||
mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
|
input_mask = input_mask.reshape((-1, 1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||||
scale_mode = "bilinear"
|
scale_mode = "bilinear"
|
||||||
|
|
||||||
if dims == 3:
|
if dims == 3:
|
||||||
if len(input_mask.shape) < 5:
|
if len(input_mask.shape) < 5:
|
||||||
mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
|
input_mask = input_mask.reshape((1, 1, -1, input_mask.shape[-2], input_mask.shape[-1]))
|
||||||
scale_mode = "trilinear"
|
scale_mode = "trilinear"
|
||||||
|
|
||||||
mask = torch.nn.functional.interpolate(mask, size=output_shape[2:], mode=scale_mode)
|
mask = torch.nn.functional.interpolate(input_mask, size=output_shape[2:], mode=scale_mode)
|
||||||
if mask.shape[1] < output_shape[1]:
|
if mask.shape[1] < output_shape[1]:
|
||||||
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
|
mask = mask.repeat((1, output_shape[1]) + (1,) * dims)[:,:output_shape[1]]
|
||||||
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
|
mask = comfy.utils.repeat_to_batch_size(mask, output_shape[0])
|
||||||
|
|||||||
@ -26,8 +26,8 @@ class X0(comfy.model_sampling.EPS):
|
|||||||
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
|
class ModelSamplingDiscreteDistilled(comfy.model_sampling.ModelSamplingDiscrete):
|
||||||
original_timesteps = 50
|
original_timesteps = 50
|
||||||
|
|
||||||
def __init__(self, model_config=None):
|
def __init__(self, model_config=None, zsnr=None):
|
||||||
super().__init__(model_config)
|
super().__init__(model_config, zsnr=zsnr)
|
||||||
|
|
||||||
self.skip_steps = self.num_timesteps // self.original_timesteps
|
self.skip_steps = self.num_timesteps // self.original_timesteps
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user