Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Segmentation Ordering for Yolo segmentation #289

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ def INPUT_TYPES(cls):
"model_name": (folder_paths.get_filename_list("yolov8"), ),
"index": ("INT", { "default": 0, "min": 0, "max": 256, "step": 1 }),
},
"optional": {
"sort_order": (["left-right", "right-left", "top-bottom", "bottom-top", "largest-smallest", "smallest-largest"], ),
}
}

CATEGORY = "SwarmUI/masks"
RETURN_TYPES = ("MASK",)
FUNCTION = "seg"

def seg(self, image, model_name, index):
def seg(self, image, model_name, index, sort_order="left-right"):
# TODO: Batch support?
i = 255.0 * image[0].cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
Expand All @@ -40,22 +43,46 @@ def seg(self, image, model_name, index):
else:
masks = masks.data.cpu()
masks = torch.nn.functional.interpolate(masks.unsqueeze(1), size=(image.shape[1], image.shape[2]), mode="bilinear").squeeze(1)

sortedindices = self.sort_masks(masks, sort_order)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is in the wrong place. It was in the correct place previously; please move it back.

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am really sorry, but I am having a super dense moment here. I don't think I see it. Do you mean where it used to be, in this block:

        if index == 0:
            result = masks[0]
            for i in range(1, len(masks)):
                result = torch.max(result, masks[i])
            return (result, )
        elif index > len(masks):
            return (torch.zeros_like(masks[0]), )
        else:
            sortedindices = []      ### Here?
            for mask in masks:
                sum_x = (torch.sum(mask, dim=0) != 0).to(dtype=torch.int)
                val = torch.argmax(sum_x).item()
                sortedindices.append(val)
            sortedindices = np.argsort(sortedindices)
            masks = masks[sortedindices]
            return (masks[index - 1].unsqueeze(0), )

Or where I had it in my previous commit where I was sorting boxes and ignoring masks?
I don't mean to make things difficult; I have just had very, very little sleep with the newest phase the kiddo is in, so I would really appreciate a little pointer in the right direction.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where it was before your PR is the correct place for it to be. Index 0 and any index greater than the mask count do not require sorting.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, yes, that makes sense. Thank you. I'll fix this.

sortedindices = torch.tensor(np.ascontiguousarray(sortedindices))
masks = masks[sortedindices]

if index == 0:
result = masks[0]
for i in range(1, len(masks)):
result = torch.max(result, masks[i])
return (result, )
return (result.unsqueeze(0),)
elif index > len(masks):
return (torch.zeros_like(masks[0]), )
return (torch.zeros_like(masks[0]).unsqueeze(0),)
else:
return (masks[index - 1].unsqueeze(0),)

def sort_masks(self, masks, sort_order):
sortedindices = []
for mask in masks:
match sort_order:
case "left-right":
sum_x = (torch.sum(mask, dim=0) != 0).to(dtype=torch.int)
val = torch.argmax(sum_x).item()
case "right-left":
sum_x = (torch.sum(mask, dim=0) != 0).to(dtype=torch.int)
val = mask.shape[1] - torch.argmax(torch.flip(sum_x, [0])).item() - 1
case "top-bottom":
sum_y = (torch.sum(mask, dim=1) != 0).to(dtype=torch.int)
val = torch.argmax(sum_y).item()
case "bottom-top":
sum_y = (torch.sum(mask, dim=1) != 0).to(dtype=torch.int)
val = mask.shape[0] - torch.argmax(torch.flip(sum_y, [0])).item() - 1
case "largest-smallest" | "smallest-largest":
val = torch.sum(mask).item()
sortedindices.append(val)

sorted_indices_array = np.array(sortedindices)
if sort_order in ["right-left", "bottom-top", "largest-smallest"]:
return np.argsort(sorted_indices_array)[::-1].copy()
else:
sortedindices = []
for mask in masks:
sum_x = (torch.sum(mask, dim=0) != 0).to(dtype=torch.int)
val = torch.argmax(sum_x).item()
sortedindices.append(val)
sortedindices = np.argsort(sortedindices)
masks = masks[sortedindices]
return (masks[index - 1].unsqueeze(0), )
return np.argsort(sorted_indices_array)

NODE_CLASS_MAPPINGS = {
"SwarmYoloDetection": SwarmYoloDetection,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1054,11 +1054,13 @@ JArray doMaskShrinkApply(WorkflowGenerator g, JArray imgIn)
{
index = 0;
}
string sortOrder = g.UserInput.Get(T2IParamTypes.SegmentationSortOrder, "left-right");
segmentNode = g.CreateNode("SwarmYoloDetection", new JObject()
{
["image"] = g.FinalImageOut,
["model_name"] = fullname,
["index"] = index
["index"] = index,
["sort_order"] = sortOrder,
});
}
else
Expand Down
9 changes: 6 additions & 3 deletions src/Text2Image/T2IParamTypes.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ public enum ParamViewType
/// <param name="Subtype">The sub-type of the type - for models, this might be eg "Stable-Diffusion".</param>
/// <param name="ID">The raw ID of this parameter (will be set when registering).</param>
/// <param name="SharpType">The C# datatype.</param>
///
///
public record class T2IParamType(string Name, string Description, string Default, double Min = 0, double Max = 0, double Step = 1, double ViewMin = 0, double ViewMax = 0,
Func<string, string, string> Clean = null, Func<Session, List<string>> GetValues = null, string[] Examples = null, Func<List<string>, List<string>> ParseList = null, bool ValidateValues = true,
bool VisibleNormally = true, bool IsAdvanced = false, string FeatureFlag = null, string Permission = null, bool Toggleable = false, double OrderPriority = 10, T2IParamGroup Group = null, string IgnoreIf = null,
Expand Down Expand Up @@ -279,7 +279,7 @@ public static string ApplyStringEdit(string prior, string update)
return update;
}

public static T2IRegisteredParam<string> Prompt, NegativePrompt, AspectRatio, BackendType, RefinerMethod, FreeUApplyTo, FreeUVersion, PersonalNote, VideoFormat, VideoResolution, UnsamplerPrompt, ImageFormat, MaskBehavior, RawResolution, SeamlessTileable, SD3TextEncs, BitDepth, Webhooks;
public static T2IRegisteredParam<string> Prompt, NegativePrompt, AspectRatio, BackendType, RefinerMethod, FreeUApplyTo, FreeUVersion, PersonalNote, VideoFormat, VideoResolution, UnsamplerPrompt, ImageFormat, MaskBehavior, RawResolution, SeamlessTileable, SD3TextEncs, BitDepth, Webhooks, SegmentationSortOrder;
public static T2IRegisteredParam<int> Images, Steps, Width, Height, BatchSize, ExactBackendID, VAETileSize, ClipStopAtLayer, VideoFrames, VideoMotionBucket, VideoFPS, VideoSteps, RefinerSteps, CascadeLatentCompression, MaskShrinkGrow, MaskBlur, MaskGrow, SegmentMaskBlur, SegmentMaskGrow;
public static T2IRegisteredParam<long> Seed, VariationSeed, WildcardSeed;
public static T2IRegisteredParam<double> CFGScale, VariationSeedStrength, InitImageCreativity, InitImageResetToNorm, RefinerControl, RefinerUpscale, RefinerCFGScale, ReVisionStrength, AltResolutionHeightMult,
Expand All @@ -288,7 +288,7 @@ public static string ApplyStringEdit(string prior, string update)
public static T2IRegisteredParam<T2IModel> Model, RefinerModel, VAE, ReVisionModel, RegionalObjectInpaintingModel, SegmentModel, VideoModel, RefinerVAE;
public static T2IRegisteredParam<List<string>> Loras, LoraWeights, LoraSectionConfinement;
public static T2IRegisteredParam<List<Image>> PromptImages;
public static T2IRegisteredParam<bool> SaveIntermediateImages, DoNotSave, ControlNetPreviewOnly, RevisionZeroPrompt, RemoveBackground, NoSeedIncrement, NoPreviews, VideoBoomerang, ModelSpecificEnhancements, UseInpaintingEncode, MaskCompositeUnthresholded, SaveSegmentMask, InitImageRecompositeMask, UseReferenceOnly, RefinerDoTiling, AutomaticVAE, ZeroNegative;
public static T2IRegisteredParam<bool> SaveIntermediateImages, DoNotSave, ControlNetPreviewOnly, RevisionZeroPrompt, RemoveBackground, NoSeedIncrement, NoPreviews, VideoBoomerang, ModelSpecificEnhancements, UseInpaintingEncode, MaskCompositeUnthresholded, SaveSegmentMask, InitImageRecompositeMask, UseReferenceOnly, RefinerDoTiling, AutomaticVAE, ZeroNegative, ReverseSegmentationOrder;

public static T2IParamGroup GroupRevision, GroupCore, GroupVariation, GroupResolution, GroupSampling, GroupInitImage, GroupRefiners,
GroupAdvancedModelAddons, GroupSwarmInternal, GroupFreeU, GroupRegionalPrompting, GroupAdvancedSampling, GroupVideo;
Expand Down Expand Up @@ -634,6 +634,9 @@ static List<string> listVaes(Session s)
SegmentMaskGrow = Register<int>(new("Segment Mask Grow", "Number of pixels of grow the segment mask by.\nThis is for '<segment:>' syntax usage.\nDefaults to 16.",
"16", Min: 0, Max: 512, Group: GroupRegionalPrompting, Examples: ["0", "4", "8", "16", "32"], Toggleable: true, OrderPriority: 5
));
SegmentationSortOrder = Register<string>(new("Segmentation Sort Order", "How to sort segments when using '<segment:yolo->' syntax.\nleft-right, right-left, top-bottom, bottom-top, largest-smallest, or smallest-largest.\nYou can also use an index to specify a segment in the given order.\nExmaple: <segment:yolo-face_yolov9c.pt-2> when largest-smallest, will select the second largest face segment.",
"left-right", Toggleable: true, IgnoreIf: "left-right", GetValues: _ => ["left-right", "right-left", "top-bottom", "bottom-top", "largest-smallest", "smallest-largest"], Group: GroupRegionalPrompting, OrderPriority: 5.5
));
SegmentThresholdMax = Register<double>(new("Segment Threshold Max", "Maximum mask match value of a segment before clamping.\nLower values force more of the mask to be counted as maximum masking.\nToo-low values may include unwanted areas of the image.\nHigher values may soften the mask.",
"1", Min: 0.01, Max: 1, Step: 0.05, Toggleable: true, ViewType: ParamViewType.SLIDER, Group: GroupRegionalPrompting, OrderPriority: 6
));
Expand Down