"""Example: train a U-Net semantic segmentation model on the CamSeq01 dataset.

The first row of the dataset is held out for a single demo prediction; all
remaining rows are handed to Ludwig for training. The predicted mask of the
held-out image is saved next to this script as ``predicted_<mask file name>``.

Run the script from its own directory: ``./config_camseq.yaml`` and
``./results`` are resolved against the current working directory.
"""
import logging
import os
import shutil

import pandas as pd
import torch
import yaml
from torchvision.utils import save_image

from ludwig.api import LudwigModel
from ludwig.datasets import camseq

# clean out prior results
shutil.rmtree("./results", ignore_errors=True)

# set up Python dictionary to hold model training parameters
with open("./config_camseq.yaml") as f:
    config = yaml.safe_load(f)

# Define Ludwig model object that drives model training
model = LudwigModel(config, logging_level=logging.INFO)

# load Camseq dataset
df = camseq.load(split=False)

# Prediction hold-out: 1 image. `reset_index` returns a new frame, so the slice
# of `df` is never mutated in place (avoids pandas' SettingWithCopy hazard) and
# `drop=True` keeps the old index from leaking in as an extra column.
pred_set = df[0:1].reset_index(drop=True)
data_set = df[1:]  # train,test,validate on remaining images

# initiate model training
# The first element of the returned tuple holds the training statistics, the
# last one the location of the training results saved to disk.
train_stats, _, output_directory = model.train(
    dataset=data_set,
    experiment_name="simple_image_experiment",
    model_name="single_model",
    skip_save_processed_input=True,
)

# predict
pred_out_df, results = model.predict(pred_set)

# A non-pandas result is materialized into a pandas DataFrame first.
if not isinstance(pred_out_df, pd.DataFrame):
    pred_out_df = pred_out_df.compute()
pred_out_df["image_path"] = pred_set["image_path"]
pred_out_df["mask_path"] = pred_set["mask_path"]

script_dir = os.path.dirname(os.path.realpath(__file__))
for _, row in pred_out_df.iterrows():
    pred_mask = torch.from_numpy(row["mask_path_predictions"])
    pred_mask_path = os.path.join(script_dir, "predicted_" + os.path.basename(row["mask_path"]))
    print(f"\nSaving predicted mask to {pred_mask_path}")
    # Values above 1 are treated as 0-255 channel values and rescaled to [0, 1]
    # before being written out.
    if torch.any(pred_mask.gt(1)):
        pred_mask = pred_mask.float() / 255
    save_image(pred_mask, pred_mask_path)
    print("Input image_path: {}".format(row["image_path"]))
    print("Label mask_path: {}".format(row["mask_path"]))
    print(f"Predicted mask_path: {pred_mask_path}")
print("Label mask_path: {}".format(row["mask_path"])) + print(f"Predicted mask_path: {pred_mask_path}") diff --git a/examples/semantic_segmentation/config_camseq.yaml b/examples/semantic_segmentation/config_camseq.yaml new file mode 100644 index 00000000000..8e018a56e3f --- /dev/null +++ b/examples/semantic_segmentation/config_camseq.yaml @@ -0,0 +1,33 @@ +input_features: + - name: image_path + type: image + preprocessing: + num_processes: 6 + infer_image_max_height: 1024 + infer_image_max_width: 1024 + encoder: unet + +output_features: + - name: mask_path + type: image + preprocessing: + num_processes: 6 + infer_image_max_height: 1024 + infer_image_max_width: 1024 + infer_image_num_classes: true + num_classes: 32 + decoder: + type: unet + num_fc_layers: 0 + loss: + type: softmax_cross_entropy + +combiner: + type: concat + num_fc_layers: 0 + +trainer: + epochs: 100 + early_stop: -1 + batch_size: 1 + max_batch_size: 1 diff --git a/ludwig/constants.py b/ludwig/constants.py index 2ddc014a6c7..b963f0d6884 100644 --- a/ludwig/constants.py +++ b/ludwig/constants.py @@ -41,6 +41,8 @@ INFER_IMAGE_MAX_HEIGHT = "infer_image_max_height" INFER_IMAGE_MAX_WIDTH = "infer_image_max_width" INFER_IMAGE_SAMPLE_SIZE = "infer_image_sample_size" +INFER_IMAGE_NUM_CLASSES = "infer_image_num_classes" +IMAGE_MAX_CLASSES = 128 NUM_CLASSES = "num_classes" NUM_CHANNELS = "num_channels" REQUIRES_EQUAL_DIMENSIONS = "requires_equal_dimensions" diff --git a/ludwig/datasets/configs/camseq.yaml b/ludwig/datasets/configs/camseq.yaml new file mode 100644 index 00000000000..ce426b4f5ea --- /dev/null +++ b/ludwig/datasets/configs/camseq.yaml @@ -0,0 +1,21 @@ +version: 1.0 +name: camseq +kaggle_dataset_id: carlolepelaars/camseq-semantic-segmentation +archive_filenames: camseq-semantic-segmentation.zip +sha256: + camseq-semantic-segmentation.zip: ea3aeba2661d9b3e3ea406668e7d9240cb2ba0c7e374914bb6d866147faff502 +loader: camseq.CamseqLoader +preserve_paths: + - images + - masks +description: | + CamSeq01 
Cambridge Labeled Objects in Video + https://www.kaggle.com/datasets/carlolepelaars/camseq-semantic-segmentation +columns: + - name: image_path + type: image + - name: mask_path + type: image +output_features: + - name: mask_path + type: image diff --git a/ludwig/datasets/loaders/camseq.py b/ludwig/datasets/loaders/camseq.py new file mode 100644 index 00000000000..d187fbd410e --- /dev/null +++ b/ludwig/datasets/loaders/camseq.py @@ -0,0 +1,61 @@ +# Copyright (c) 2023 Aizen Corp. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class CamseqLoader(DatasetLoader):
    """Dataset loader for the CamSeq01 semantic segmentation dataset.

    The raw archive stores images (``<name>.png``) and their label masks
    (``<name>_L.png``) side by side. ``transform_files`` separates them into
    ``images/`` and ``masks/`` sub-directories and ``load_unprocessed_dataframe``
    pairs each image with its mask path.
    """

    def transform_files(self, file_paths: List[str]) -> List[str]:
        """Move raw PNG files into ``images/`` and ``masks/`` sub-directories.

        Args:
            file_paths: Unused; the raw dataset directory is listed directly.

        Returns:
            Whatever the parent ``transform_files`` returns for the relocated files.
        """
        makedirs(self.processed_dataset_dir, exist_ok=True)

        # move images and masks into separate directories
        source_dir = self.raw_dataset_dir
        images_dir = os.path.join(source_dir, "images")
        masks_dir = os.path.join(source_dir, "masks")
        makedirs(images_dir, exist_ok=True)
        makedirs(masks_dir, exist_ok=True)

        data_files = []
        # os.listdir returns entries in arbitrary order; sort for reproducibility.
        for f in sorted(os.listdir(source_dir)):
            # The mask suffix must be tested first: "_L.png" also ends with ".png".
            if f.endswith("_L.png"):  # masks
                dest_file = os.path.join(masks_dir, f)
            elif f.endswith(".png"):  # images
                dest_file = os.path.join(images_dir, f)
            else:
                continue
            os.replace(os.path.join(source_dir, f), dest_file)
            data_files.append(dest_file)

        return super().transform_files(data_files)

    def load_unprocessed_dataframe(self, file_paths: List[str]) -> pd.DataFrame:
        """Creates a dataframe of image paths and mask paths.

        Rows are ordered by image file name so the dataframe (and therefore any
        positional hold-out such as ``df[0:1]``) is identical across file systems.

        Args:
            file_paths: Unused; the processed ``images/`` directory is listed directly.

        Returns:
            DataFrame with columns ``image_path`` and ``mask_path``.
        """
        images_dir = os.path.join(self.processed_dataset_dir, "images")
        masks_dir = os.path.join(self.processed_dataset_dir, "masks")
        images = []
        masks = []
        for f in sorted(os.listdir(images_dir)):
            stem, ext = os.path.splitext(f)
            if ext != ".png":
                # Ignore stray files (e.g. OS metadata) that have no mask counterpart.
                continue
            images.append(os.path.join(images_dir, f))
            masks.append(os.path.join(masks_dir, stem + "_L.png"))

        return pd.DataFrame({"image_path": images, "mask_path": masks})
@DeveloperAPI
@register_decoder("unet", IMAGE)
class UNetDecoder(Decoder):
    """U-Net decoder for image output features.

    Wraps a ``UNetUpStack`` that turns the combiner's hidden tensor plus the
    encoder's skip connections into per-pixel class logits.
    """

    def __init__(
        self,
        input_size: int,
        height: int,
        width: int,
        num_channels: int = 1,
        num_classes: int = 2,
        conv_norm: Optional[str] = None,
        decoder_config=None,
        **kwargs,
    ):
        """
        Args:
            input_size: Accepted for interface compatibility; not used by this decoder.
            height: Output image height; must be a positive multiple of 16.
            width: Output image width; must be a positive multiple of 16.
            num_channels: Accepted for interface compatibility; not used by this decoder.
            num_classes: Number of output classes (logit channels); at least 2.
            conv_norm: Normalization passed to the up stack's conv layers.
            decoder_config: Decoder config object, stored on ``self.config``.

        Raises:
            ValueError: If ``num_classes`` < 2, or ``height``/``width`` is not a
                positive multiple of 16.
        """
        super().__init__()
        self.config = decoder_config
        self.num_classes = num_classes

        logger.debug(f" {self.name}")
        if num_classes < 2:
            raise ValueError(f"Invalid `num_classes` {num_classes} for unet decoder")
        # The up stack (default depth 4) starts from height/16 x width/16, so both
        # dimensions must be positive multiples of 16. A bare `% 16` check would
        # let 0 through.
        if height <= 0 or width <= 0 or height % 16 or width % 16:
            raise ValueError(f"Invalid `height` {height} or `width` {width} for unet decoder")

        self.unet = UNetUpStack(
            img_height=height,
            img_width=width,
            out_channels=num_classes,
            norm=conv_norm,
        )

        # Shape used to unflatten the combiner output: [-1, *unet input shape].
        self.input_reshape = [-1, *self.unet.input_shape]
        self._output_shape = (height, width)

    def forward(self, combiner_outputs: Dict[str, torch.Tensor], target: torch.Tensor):
        """Decode combiner outputs into logits and per-pixel class predictions.

        Args:
            combiner_outputs: Must contain ``HIDDEN`` (the combined hidden tensor)
                and ``ENCODER_OUTPUT_STATE`` (the encoder's skip connections).
            target: Unused.

        Returns:
            Dict with ``LOGITS`` (the up stack output) and ``PREDICTIONS``
            (argmax of the logits over dim 1, as a byte tensor).
        """
        hidden = combiner_outputs[HIDDEN]
        skips = combiner_outputs[ENCODER_OUTPUT_STATE]

        # unflatten combiner outputs
        hidden = hidden.reshape(self.input_reshape)

        logits = self.unet(hidden, skips)
        # argmax already removes the class dimension; no further squeeze is needed.
        predictions = logits.argmax(dim=1).byte()

        return {LOGITS: logits, PREDICTIONS: predictions}

    def get_prediction_set(self):
        """Return the names of the tensors produced by ``forward``."""
        return {LOGITS, PREDICTIONS}

    @staticmethod
    def get_schema_cls() -> Type[ImageDecoderConfig]:
        return UNetDecoderConfig

    @property
    def output_shape(self) -> torch.Size:
        """(height, width) of the predicted mask."""
        return torch.Size(self._output_shape)

    @property
    def input_shape(self) -> torch.Size:
        """Input shape expected by the wrapped up stack."""
        return self.unet.input_shape
@DeveloperAPI
@register_encoder("unet", IMAGE)
class UNetEncoder(ImageEncoder):
    """U-Net image encoder built on ``UNetDownStack``.

    ``forward`` returns the down stack's final hidden tensor as the encoder
    output and its skip connections as the encoder output state.
    """

    def __init__(
        self,
        height: int,
        width: int,
        num_channels: int = 3,
        conv_norm: Optional[str] = None,
        encoder_config=None,
        **kwargs,
    ):
        """
        Args:
            height: Input image height; must be divisible by 16.
            width: Input image width; must be divisible by 16.
            num_channels: Number of channels of the input image.
            conv_norm: Normalization passed to the down stack's conv layers.
            encoder_config: Encoder config object, stored on ``self.config``.

        Raises:
            ValueError: If ``height`` or ``width`` is not divisible by 16.
        """
        super().__init__()
        self.config = encoder_config

        logger.debug(f" {self.name}")

        # The down stack (default depth 4) halves each dimension four times.
        height_remainder = height % 16
        width_remainder = width % 16
        if height_remainder or width_remainder:
            raise ValueError(f"Invalid `height` {height} or `width` {width} for unet encoder")

        down_stack_kwargs = dict(
            img_height=height,
            img_width=width,
            in_channels=num_channels,
            norm=conv_norm,
        )
        self.unet = UNetDownStack(**down_stack_kwargs)

    def forward(self, inputs: torch.Tensor) -> EncoderOutputDict:
        """Run the down stack and package its two outputs."""
        encoded, skip_connections = self.unet(inputs)
        outputs = {ENCODER_OUTPUT: encoded}
        outputs[ENCODER_OUTPUT_STATE] = skip_connections
        return outputs

    @staticmethod
    def get_schema_cls() -> Type[ImageEncoderConfig]:
        return UNetEncoderConfig

    @property
    def output_shape(self) -> torch.Size:
        """Output shape reported by the wrapped down stack."""
        return self.unet.output_shape

    @property
    def input_shape(self) -> torch.Size:
        """Input shape reported by the wrapped down stack."""
        return self.unet.input_shape
SequenceOutputFeature from ludwig.features.set_feature import SetFeatureMixin, SetInputFeature, SetOutputFeature @@ -108,6 +108,7 @@ def get_output_type_registry() -> Dict: TIMESERIES: TimeseriesOutputFeature, VECTOR: VectorOutputFeature, CATEGORY_DISTRIBUTION: CategoryDistributionOutputFeature, + IMAGE: ImageOutputFeature, } diff --git a/ludwig/features/image_feature.py b/ludwig/features/image_feature.py index b6c8548fb33..81c389d0935 100644 --- a/ludwig/features/image_feature.py +++ b/ludwig/features/image_feature.py @@ -37,9 +37,12 @@ INFER_IMAGE_DIMENSIONS, INFER_IMAGE_MAX_HEIGHT, INFER_IMAGE_MAX_WIDTH, + INFER_IMAGE_NUM_CLASSES, INFER_IMAGE_SAMPLE_SIZE, + LOGITS, NAME, NUM_CHANNELS, + PREDICTIONS, PREPROCESSING, PROC_COLUMN, REQUIRES_EQUAL_DIMENSIONS, @@ -49,8 +52,9 @@ WIDTH, ) from ludwig.data.cache.types import wrap +from ludwig.encoders.base import Encoder from ludwig.encoders.image.torchvision import TVModelVariant -from ludwig.features.base_feature import BaseFeatureMixin, InputFeature +from ludwig.features.base_feature import BaseFeatureMixin, InputFeature, OutputFeature, PredictModule from ludwig.schema.features.augmentation.base import BaseAugmentationConfig from ludwig.schema.features.augmentation.image import ( AutoAugmentationConfig, @@ -61,16 +65,25 @@ RandomRotateConfig, RandomVerticalFlipConfig, ) -from ludwig.schema.features.image_feature import ImageInputFeatureConfig -from ludwig.types import FeatureMetadataDict, ModelConfigDict, PreprocessingConfigDict, TrainingSetMetadataDict +from ludwig.schema.features.image_feature import ImageInputFeatureConfig, ImageOutputFeatureConfig +from ludwig.types import ( + FeatureMetadataDict, + FeaturePostProcessingOutputDict, + ModelConfigDict, + PreprocessingConfigDict, + TrainingSetMetadataDict, +) +from ludwig.utils import output_feature_utils from ludwig.utils.augmentation_utils import get_augmentation_op, register_augmentation_op from ludwig.utils.data_utils import get_abs_path from 
def is_torchvision_encoder(encoder_obj: Encoder) -> bool:
    """Return True when ``encoder_obj`` is an instance of ``TVBaseEncoder``."""
    # TODO(travis): do this through an interface rather than conditional logic
    # The encoder module is imported at call time, as in the original helper.
    from ludwig.encoders.image import torchvision as tv_encoders

    return isinstance(encoder_obj, tv_encoders.TVBaseEncoder)
class _ImagePostprocessing(torch.nn.Module):
    """Postprocessing module for image output features.

    Extracts this feature's predictions and logits tensors from the flattened
    model outputs and returns them unchanged.
    """

    def __init__(self, metadata: Optional[TrainingSetMetadataDict] = None):
        """
        Args:
            metadata: Training-set metadata of the feature. Accepted (and currently
                ignored) so the module can be constructed both with and without
                metadata; a zero-argument constructor raised ``TypeError`` when the
                metadata was passed positionally.
        """
        super().__init__()
        self.logits_key = LOGITS
        self.predictions_key = PREDICTIONS

    def forward(self, preds: Dict[str, torch.Tensor], feature_name: str) -> FeaturePostProcessingOutputDict:
        """Return the predictions and logits tensors of ``feature_name`` found in ``preds``."""
        predictions = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.predictions_key)
        logits = output_feature_utils.get_output_feature_tensor(preds, feature_name, self.logits_key)

        return {self.predictions_key: predictions, self.logits_key: logits}


class _ImagePredict(PredictModule):
    """Predict module for image output features.

    Passes through the predictions and logits already computed by the decoder.
    """

    def forward(self, inputs: Dict[str, torch.Tensor], feature_name: str) -> Dict[str, torch.Tensor]:
        """Return the predictions and logits tensors of ``feature_name`` found in ``inputs``."""
        # NOTE(review): `predictions_key` / `logits_key` are presumably set by
        # PredictModule.__init__ - confirm against the base class.
        predictions = output_feature_utils.get_output_feature_tensor(inputs, feature_name, self.predictions_key)
        logits = output_feature_utils.get_output_feature_tensor(inputs, feature_name, self.logits_key)

        return {self.predictions_key: predictions, self.logits_key: logits}
@staticmethod
def _infer_image_num_classes(
    image_sample: List[torch.Tensor],
    num_channels: int,
    num_classes: int,
) -> torch.Tensor:
    """Build the channel-value-to-class map from a sample of images (for image segmentation).

    The map is obtained from ``get_unique_channels``; its first dimension is the
    class, so ``shape[0]`` is the inferred number of classes.

    Args:
        image_sample: Sample of images used to infer the classes. Must be formatted as
            [channels, height, width].
        num_channels: Expected number of channels
        num_classes: Expected number of channel classes or None

    Return:
        channel_class_map: A tensor mapping channel values to classes, where dim=0 is the class.

    Raises:
        ValueError: If more classes are inferred than the given ``num_classes`` allows.
    """
    sample_count = len(image_sample)
    logger.info(f"Inferring num_classes from the first {sample_count} images.")
    channel_class_map = get_unique_channels(image_sample, num_channels, num_classes)

    inferred_num_classes = channel_class_map.shape[0]

    # Without a user-specified class count there is nothing to cross-check.
    if not num_classes:
        return channel_class_map

    if inferred_num_classes > num_classes:
        raise ValueError(
            f"Images inferred num classes {inferred_num_classes} exceeds `num_classes` {num_classes}."
        )

    if inferred_num_classes < num_classes:
        # Fewer classes found than requested: warn and go with what was found.
        logger.warning(
            "Images inferred num classes {} does not match `num_classes` {}. "
            "Using inferred num classes {}.".format(inferred_num_classes, num_classes, inferred_num_classes)
        )

    return channel_class_map
feature_config[ENCODER].get("model_variant") if ENCODER in feature_config.keys() else None if model_variant: torchvision_parameters = _get_torchvision_parameters(model_type, model_variant) else: @@ -815,6 +915,7 @@ def add_feature_data( height = transform_metadata.height width = transform_metadata.width num_channels = transform_metadata.num_channels + channel_class_map = torch.Tensor([]) else: # torchvision_parameters is None # perform Ludwig specified transformations @@ -826,6 +927,7 @@ def add_feature_data( user_specified_num_channels, average_file_size, standardize_image, + channel_class_map, ) = ImageFeatureMixin._finalize_preprocessing_parameters( preprocessing_parameters, encoder_type, abs_path_column ) @@ -833,6 +935,8 @@ def add_feature_data( metadata[name][PREPROCESSING]["height"] = height metadata[name][PREPROCESSING]["width"] = width metadata[name][PREPROCESSING]["num_channels"] = num_channels + metadata[name][PREPROCESSING]["num_classes"] = channel_class_map.shape[0] + metadata[name][PREPROCESSING]["channel_class_map"] = channel_class_map.tolist() read_image_if_bytes_obj_and_resize = partial( ImageFeatureMixin._read_image_if_bytes_obj_and_resize, @@ -843,11 +947,16 @@ def add_feature_data( resize_method=preprocessing_parameters["resize_method"], user_specified_num_channels=user_specified_num_channels, standardize_image=standardize_image, + channel_class_map=channel_class_map, ) # TODO: alternatively use get_average_image() for unreachable images - default_image = get_gray_default_image(num_channels, height, width) - metadata[name]["reshape"] = (num_channels, height, width) + if channel_class_map.shape[0]: + default_image = get_gray_default_image(1, height, width).squeeze(0) + metadata[name]["reshape"] = (height, width) + else: + default_image = get_gray_default_image(num_channels, height, width) + metadata[name]["reshape"] = (num_channels, height, width) in_memory = feature_config[PREPROCESSING]["in_memory"] if in_memory or skip_save_processed_input: @@ 
class ImageOutputFeature(ImageFeatureMixin, OutputFeature):
    """Image output feature (e.g. a segmentation mask) backed by an image decoder."""

    def __init__(
        self,
        output_feature_config: Union[ImageOutputFeatureConfig, Dict],
        output_features: Dict[str, OutputFeature],
        **kwargs,
    ):
        """Initialize the decoder, loss and metrics from ``output_feature_config``."""
        super().__init__(output_feature_config, output_features, **kwargs)
        self.decoder_obj = self.initialize_decoder(output_feature_config.decoder)
        self._setup_loss()
        self._setup_metrics()

    def logits(self, inputs: Dict[str, torch.Tensor], target=None, **kwargs):
        """Run the decoder on the combiner outputs and return its output dict."""
        return self.decoder_obj(inputs, target=target)

    def metric_kwargs(self):
        return dict(num_outputs=self.output_shape[0])

    def create_predict_module(self) -> PredictModule:
        return _ImagePredict()

    def get_prediction_set(self):
        return self.decoder_obj.get_prediction_set()

    @classmethod
    def get_output_dtype(cls):
        return torch.float32

    @property
    def output_shape(self) -> torch.Size:
        return self.decoder_obj.output_shape

    @property
    def input_shape(self) -> torch.Size:
        return self.decoder_obj.input_shape

    @staticmethod
    def update_config_with_metadata(feature_config, feature_metadata, *args, **kwargs):
        """Copy preprocessing-derived image parameters onto the decoder config.

        Only keys that the decoder config actually declares are copied.
        """
        for key in ["height", "width", "num_channels", "num_classes", "standardize_image"]:
            if hasattr(feature_config.decoder, key):
                setattr(feature_config.decoder, key, feature_metadata[PREPROCESSING][key])

    @staticmethod
    def calculate_overall_stats(predictions, targets, metadata):
        # no overall stats, just return empty dictionary
        return {}

    def postprocess_predictions(
        self,
        result,
        metadata,
    ):
        """Convert predicted class masks back to images when a class map is available.

        Args:
            result: Dataframe of predictions; the ``<feature>_predictions`` column,
                if present, is rewritten in place.
            metadata: Feature metadata holding ``preprocessing.channel_class_map``.

        Returns:
            ``result``, with the predictions column converted when the class map is
            non-empty.
        """
        predictions_col = f"{self.feature_name}_{PREDICTIONS}"

        if predictions_col in result:
            channel_class_map = torch.ByteTensor(metadata[PREPROCESSING]["channel_class_map"])

            # An empty map means the feature was not preprocessed into class masks.
            if channel_class_map.shape[0]:

                def class_mask2img(row):
                    pred = row[predictions_col]
                    return get_image_from_class_mask(channel_class_map, pred)

                result[predictions_col] = result.apply(class_mask2img, axis=1)

        return result

    @staticmethod
    def create_postproc_module(metadata: TrainingSetMetadataDict) -> torch.nn.Module:
        # The postprocessing module needs no metadata; passing it positionally to a
        # zero-argument constructor raised TypeError.
        return _ImagePostprocessing()

    @staticmethod
    def get_schema_cls():
        return ImageOutputFeatureConfig
class UNetDoubleConvLayer(LudwigModule):
    def __init__(
        self,
        img_height: int,
        img_width: int,
        in_channels: int,
        out_channels: int,
        norm: str = None,
    ):
        """Two Conv2d layers, each followed by a ReLU, used for U-Net.

        Args:
            img_height: the input image height
            img_width: the input image width
            in_channels: the number of input channels
            out_channels: the number of output channels
            norm: the normalization to be applied; "batch" inserts a BatchNorm2d
                between each convolution and its ReLU
        """
        super().__init__()

        self.layers = nn.ModuleList()

        conv_in = in_channels
        for _ in range(2):
            self.layers.append(nn.Conv2d(conv_in, out_channels, kernel_size=3, padding=1))
            if norm == "batch":
                self.layers.append(nn.BatchNorm2d(out_channels))
            self.layers.append(nn.ReLU())
            # the second convolution maps out_channels -> out_channels
            conv_in = out_channels

        self._input_shape = (in_channels, img_height, img_width)
        self._output_shape = (out_channels, img_height, img_width)

    def forward(self, inputs):
        """Apply all layers in order."""
        hidden = inputs

        for layer in self.layers:
            hidden = layer(hidden)

        return hidden

    @property
    def output_shape(self) -> torch.Size:
        return torch.Size(self._output_shape)

    @property
    def input_shape(self) -> torch.Size:
        return torch.Size(self._input_shape)


class UNetDownStack(LudwigModule):
    def __init__(
        self,
        img_height: int,
        img_width: int,
        in_channels: int,
        norm: str = None,
        stack_depth: int = 4,
    ):
        """Creates the contracting downsampling path of a U-Net stack.

        Implements
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        https://arxiv.org/abs/1505.04597
        by Olaf Ronneberger, Philipp Fischer, Thomas Brox, May 2015.

        Args:
            img_height: the input image height
            img_width: the input image width
            in_channels: the number of input channels
            norm: the normalization to be applied
            stack_depth: the depth of the unet stack
        """
        super().__init__()

        self.conv_layers = nn.ModuleList()
        self.down_layers = nn.ModuleList()

        height = img_height
        width = img_width
        in_c = in_channels
        out_c = 64

        self._input_shape = (in_c, height, width)

        # Each level doubles the channel count and halves the spatial size.
        for _ in range(stack_depth):
            self.conv_layers.append(UNetDoubleConvLayer(height, width, in_c, out_c, norm))
            in_c = out_c
            out_c = out_c * 2

            self.down_layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            height = height // 2
            width = width // 2

        self.bottleneck = UNetDoubleConvLayer(height, width, in_c, out_c, norm)

        self._output_shape = (out_c, height, width)

    def forward(self, inputs):
        """Run the contracting path.

        Returns:
            Tuple ``(hidden, skips)``: the bottleneck output and the list of
            pre-pooling activations, ordered from the first level to the last.
        """
        skips = []  # skip connections
        hidden = inputs

        for conv_layer, down_layer in zip(self.conv_layers, self.down_layers):
            hidden = conv_layer(hidden)
            skips.append(hidden)
            hidden = down_layer(hidden)

        hidden = self.bottleneck(hidden)
        return hidden, skips

    @property
    def output_shape(self) -> torch.Size:
        return torch.Size(self._output_shape)

    @property
    def input_shape(self) -> torch.Size:
        return torch.Size(self._input_shape)


class UNetUpStack(LudwigModule):
    def __init__(
        self,
        img_height: int,
        img_width: int,
        out_channels: int,
        norm: str = None,
        stack_depth: int = 4,
    ):
        """Creates the expansive upsampling path of a U-Net stack.

        Implements
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        https://arxiv.org/abs/1505.04597
        by Olaf Ronneberger, Philipp Fischer, Thomas Brox, May 2015.

        Args:
            img_height: the output image height
            img_width: the output image width
            out_channels: the number of output classes
            norm: the normalization to be applied
            stack_depth: the depth of the unet stack
        """
        super().__init__()

        self.conv_layers = nn.ModuleList()
        self.up_layers = nn.ModuleList()

        height = img_height >> stack_depth
        width = img_width >> stack_depth
        in_c = 64 << stack_depth
        out_c = in_c // 2

        self._input_shape = (in_c, height, width)

        # Each level halves the channel count and doubles the spatial size.
        for _ in range(stack_depth):
            self.up_layers.append(nn.ConvTranspose2d(in_c, out_c, kernel_size=2, stride=2))
            height = height * 2
            width = width * 2

            # The double conv sees the upsampled tensor concatenated with its skip
            # connection, hence out_c * 2 input channels.
            self.conv_layers.append(UNetDoubleConvLayer(height, width, out_c * 2, out_c, norm))
            in_c = out_c
            out_c = out_c // 2

        self.last_conv = nn.Conv2d(in_c, out_channels, kernel_size=1, padding=0)

        self._output_shape = (out_channels, img_height, img_width)

    def forward(self, inputs, skips):
        """Run the expansive path.

        Args:
            inputs: the bottleneck tensor
            skips: skip connections ordered from the first (shallowest) level to
                the last; they are consumed last-to-first. The sequence is not
                modified, so the caller can reuse it.

        Returns:
            The output of the final 1x1 convolution.

        Raises:
            ValueError: if fewer skip connections than stack levels are given
        """
        if len(skips) < len(self.up_layers):
            raise ValueError(
                f"Expected at least {len(self.up_layers)} skip connections for unet up stack, got {len(skips)}"
            )

        hidden = inputs

        # Iterate the skips in reverse instead of popping, which would empty the
        # caller's list as a side effect.
        for conv_layer, up_layer, skip in zip(self.conv_layers, self.up_layers, reversed(skips)):
            hidden = up_layer(hidden)
            hidden = torch.cat([hidden, skip], dim=1)
            hidden = conv_layer(hidden)

        hidden = self.last_conv(hidden)
        return hidden

    @property
    def output_shape(self) -> torch.Size:
        return torch.Size(self._output_shape)

    @property
    def input_shape(self) -> torch.Size:
        return torch.Size(self._input_shape)
integral + between 0 and num_classes. + or shape [batch x H x W], where each element is integral between 0 and num_classes. """ - if len(target.shape) == 1: + if len(target.shape) == 1 or len(target.shape) == 3: # Assumes we are providing the target as a single class, rather than a distribution + # The target shape can be a 3D tensor [batch x H x W], for image segmentation target = target.long() return self.loss_fn(preds, target) diff --git a/ludwig/modules/metric_modules.py b/ludwig/modules/metric_modules.py index 84f06d30d46..2a6de164aac 100644 --- a/ludwig/modules/metric_modules.py +++ b/ludwig/modules/metric_modules.py @@ -48,6 +48,7 @@ HITS_AT_K, HUBER, IGNORE_INDEX_TOKEN_ID, + IMAGE, JACCARD, LOGITS, LOSS, @@ -307,7 +308,7 @@ def get_current_value(self, preds: Tensor, target: Tensor) -> Tensor: return self.loss_function(preds, target) -@register_metric("softmax_cross_entropy", [CATEGORY, CATEGORY_DISTRIBUTION], MINIMIZE, LOGITS) +@register_metric("softmax_cross_entropy", [CATEGORY, CATEGORY_DISTRIBUTION, IMAGE], MINIMIZE, LOGITS) class SoftmaxCrossEntropyMetric(LossMetric): def __init__(self, config: SoftmaxCrossEntropyLossConfig, **kwargs): super().__init__() diff --git a/ludwig/schema/decoders/__init__.py b/ludwig/schema/decoders/__init__.py index bcc84aeeeb6..9663505403d 100644 --- a/ludwig/schema/decoders/__init__.py +++ b/ludwig/schema/decoders/__init__.py @@ -1,4 +1,5 @@ # Register all decoders import ludwig.schema.decoders.base +import ludwig.schema.decoders.image_decoders # noqa import ludwig.schema.decoders.llm_decoders # noqa import ludwig.schema.decoders.sequence_decoders # noqa diff --git a/ludwig/schema/decoders/image_decoders.py b/ludwig/schema/decoders/image_decoders.py new file mode 100644 index 00000000000..4ea0a11286b --- /dev/null +++ b/ludwig/schema/decoders/image_decoders.py @@ -0,0 +1,74 @@ +from typing import Optional, TYPE_CHECKING + +from ludwig.api_annotations import DeveloperAPI +from ludwig.constants import IMAGE, MODEL_ECD +from 
class ImageDecoderConfig(BaseDecoderConfig):
    """Base class for the configs of decoders that produce image outputs."""

    def set_fixed_preprocessing_params(self, model_type: str, preprocessing: "ImagePreprocessingConfig"):
        """Resets the size constraints on the given image preprocessing config.

        Sets `requires_equal_dimensions` to False and clears `height` and `width`.
        `model_type` is accepted but not used here.
        """
        preprocessing.requires_equal_dimensions = False
        preprocessing.height = None
        preprocessing.width = None


@DeveloperAPI
@register_decoder_config("unet", [IMAGE], model_types=[MODEL_ECD])
@ludwig_dataclass
class UNetDecoderConfig(ImageDecoderConfig):
    """Config for the U-Net image decoder, registered as decoder type `unet` for image features."""

    @staticmethod
    def module_name():
        return "UNetDecoder"

    type: str = schema_utils.ProtectedString(
        "unet",
        description=DECODER_METADATA["UNetDecoder"]["type"].long_description,
    )

    input_size: int = schema_utils.PositiveInteger(
        default=1024,
        description="Size of the input to the decoder.",
        parameter_metadata=DECODER_METADATA["UNetDecoder"]["input_size"],
    )

    # height, width, num_channels and num_classes all default to None, hence Optional.
    height: Optional[int] = schema_utils.NonNegativeInteger(
        default=None,
        allow_none=True,
        description="Height of the output image.",
        parameter_metadata=DECODER_METADATA["UNetDecoder"]["height"],
    )

    width: Optional[int] = schema_utils.NonNegativeInteger(
        default=None,
        allow_none=True,
        description="Width of the output image.",
        parameter_metadata=DECODER_METADATA["UNetDecoder"]["width"],
    )

    num_channels: Optional[int] = schema_utils.NonNegativeInteger(
        default=None,
        allow_none=True,
        description="Number of channels in the output image. ",
        parameter_metadata=DECODER_METADATA["UNetDecoder"]["num_channels"],
    )

    conv_norm: Optional[str] = schema_utils.StringOptions(
        ["batch"],
        default="batch",
        allow_none=True,
        # The two literals are concatenated, so the first one must end with a space.
        description="This is the default norm that will be used for each double conv layer. "
        "It can be null or batch.",
        parameter_metadata=DECODER_METADATA["UNetDecoder"]["conv_norm"],
    )

    num_classes: Optional[int] = schema_utils.NonNegativeInteger(
        default=None,
        allow_none=True,
        description="Number of classes to predict in the output. ",
        # Use the metadata entry that decoders.yaml declares for this parameter, like its siblings.
        parameter_metadata=DECODER_METADATA["UNetDecoder"]["num_classes"],
    )
"It can be null or batch.", + parameter_metadata=ENCODER_METADATA["UNetEncoder"]["conv_norm"], + ) diff --git a/ludwig/schema/features/image_feature.py b/ludwig/schema/features/image_feature.py index cf66c5b1965..9322ee253cb 100644 --- a/ludwig/schema/features/image_feature.py +++ b/ludwig/schema/features/image_feature.py @@ -1,17 +1,29 @@ from typing import List from ludwig.api_annotations import DeveloperAPI -from ludwig.constants import IMAGE, MODEL_ECD +from ludwig.constants import IMAGE, LOSS, MODEL_ECD, SOFTMAX_CROSS_ENTROPY from ludwig.schema import utils as schema_utils +from ludwig.schema.decoders.base import BaseDecoderConfig +from ludwig.schema.decoders.utils import DecoderDataclassField from ludwig.schema.encoders.base import BaseEncoderConfig from ludwig.schema.encoders.utils import EncoderDataclassField from ludwig.schema.features.augmentation.base import BaseAugmentationConfig from ludwig.schema.features.augmentation.image import RandomHorizontalFlipConfig, RandomRotateConfig from ludwig.schema.features.augmentation.utils import AugmentationDataclassField -from ludwig.schema.features.base import BaseInputFeatureConfig +from ludwig.schema.features.base import BaseInputFeatureConfig, BaseOutputFeatureConfig +from ludwig.schema.features.loss.loss import BaseLossConfig +from ludwig.schema.features.loss.utils import LossDataclassField from ludwig.schema.features.preprocessing.base import BasePreprocessingConfig from ludwig.schema.features.preprocessing.utils import PreprocessingDataclassField -from ludwig.schema.features.utils import ecd_defaults_config_registry, ecd_input_config_registry, input_mixin_registry +from ludwig.schema.features.utils import ( + ecd_defaults_config_registry, + ecd_input_config_registry, + ecd_output_config_registry, + input_mixin_registry, + output_mixin_registry, +) +from ludwig.schema.metadata import FEATURE_METADATA +from ludwig.schema.metadata.parameter_metadata import INTERNAL_ONLY from ludwig.schema.utils import 
BaseMarshmallowConfig, ludwig_dataclass # Augmentation operations when augmentation is set to True @@ -56,3 +68,66 @@ class ImageInputFeatureConfig(ImageInputFeatureConfigMixin, BaseInputFeatureConf """ImageInputFeatureConfig is a dataclass that configures the parameters used for an image input feature.""" type: str = schema_utils.ProtectedString(IMAGE) + + +@DeveloperAPI +@output_mixin_registry.register(IMAGE) +@ludwig_dataclass +class ImageOutputFeatureConfigMixin(BaseMarshmallowConfig): + """ImageOutputFeatureConfigMixin is a dataclass that configures the parameters used in both the image output + feature and the image global defaults section of the Ludwig Config.""" + + decoder: BaseDecoderConfig = DecoderDataclassField( + MODEL_ECD, + feature_type=IMAGE, + default="unet", + ) + + loss: BaseLossConfig = LossDataclassField( + feature_type=IMAGE, + default=SOFTMAX_CROSS_ENTROPY, + ) + + +@DeveloperAPI +@ecd_output_config_registry.register(IMAGE) +@ludwig_dataclass +class ImageOutputFeatureConfig(ImageOutputFeatureConfigMixin, BaseOutputFeatureConfig): + """ImageOutputFeatureConfig is a dataclass that configures the parameters used for an image output feature.""" + + type: str = schema_utils.ProtectedString(IMAGE) + + dependencies: list = schema_utils.List( + default=[], + description="List of input features that this feature depends on.", + parameter_metadata=FEATURE_METADATA[IMAGE]["dependencies"], + ) + + default_validation_metric: str = schema_utils.StringOptions( + [LOSS], + default=LOSS, + description="Internal only use parameter: default validation metric for image output feature.", + parameter_metadata=INTERNAL_ONLY, + ) + + preprocessing: BasePreprocessingConfig = PreprocessingDataclassField(feature_type="image_output") + + reduce_dependencies: str = schema_utils.ReductionOptions( + default=None, + description="How to reduce the dependencies of the output feature.", + parameter_metadata=FEATURE_METADATA[IMAGE]["reduce_dependencies"], + ) + + reduce_input: 
str = schema_utils.ReductionOptions( + default=None, + description="How to reduce an input that is not a vector, but a matrix or a higher order tensor, on the first " + "dimension (second if you count the batch dimension)", + parameter_metadata=FEATURE_METADATA[IMAGE]["reduce_input"], + ) + + +@DeveloperAPI +@ecd_defaults_config_registry.register(IMAGE) +@ludwig_dataclass +class ImageDefaultsConfig(ImageInputFeatureConfigMixin, ImageOutputFeatureConfigMixin): + pass diff --git a/ludwig/schema/features/loss/loss.py b/ludwig/schema/features/loss/loss.py index f4ec1472e9d..2dffbe34293 100644 --- a/ludwig/schema/features/loss/loss.py +++ b/ludwig/schema/features/loss/loss.py @@ -7,6 +7,7 @@ CATEGORY, CORN, HUBER, + IMAGE, MEAN_ABSOLUTE_ERROR, MEAN_ABSOLUTE_PERCENTAGE_ERROR, MEAN_SQUARED_ERROR, @@ -251,7 +252,7 @@ def name(self) -> str: @DeveloperAPI -@register_loss([CATEGORY, VECTOR]) +@register_loss([CATEGORY, VECTOR, IMAGE]) @ludwig_dataclass class SoftmaxCrossEntropyLossConfig(BaseLossConfig): type: str = schema_utils.ProtectedString( diff --git a/ludwig/schema/features/preprocessing/image.py b/ludwig/schema/features/preprocessing/image.py index 95bfacbfadd..99a59169d66 100644 --- a/ludwig/schema/features/preprocessing/image.py +++ b/ludwig/schema/features/preprocessing/image.py @@ -1,7 +1,7 @@ from typing import Optional, Union from ludwig.api_annotations import DeveloperAPI -from ludwig.constants import BFILL, IMAGE, IMAGENET1K, MISSING_VALUE_STRATEGY_OPTIONS, PREPROCESSING +from ludwig.constants import BFILL, DROP_ROW, IMAGE, IMAGENET1K, MISSING_VALUE_STRATEGY_OPTIONS, PREPROCESSING from ludwig.schema import utils as schema_utils from ludwig.schema.features.preprocessing.base import BasePreprocessingConfig from ludwig.schema.features.preprocessing.utils import register_preprocessor @@ -139,3 +139,32 @@ class ImagePreprocessingConfig(BasePreprocessingConfig): description="If true, then width and height must be equal.", 
parameter_metadata=FEATURE_METADATA[IMAGE][PREPROCESSING]["requires_equal_dimensions"], ) + + num_classes: int = schema_utils.PositiveInteger( + default=None, + allow_none=True, + description="Number of channel classes in the images. If specified, this value will be validated " + "against the inferred number of classes. ", + parameter_metadata=FEATURE_METADATA[IMAGE][PREPROCESSING]["num_classes"], + ) + + infer_image_num_classes: bool = schema_utils.Boolean( + default=False, + description="If true, then the number of channel classes in the dataset will be inferred from a sample of " + "the first image in the dataset. Each unique channel value will be mapped to a class and preprocessing will " + "create a masked image based on the channel classes. ", + parameter_metadata=FEATURE_METADATA[IMAGE][PREPROCESSING]["infer_image_num_classes"], + ) + + +@DeveloperAPI +@register_preprocessor("image_output") +@ludwig_dataclass +class ImageOutputPreprocessingConfig(ImagePreprocessingConfig): + missing_value_strategy: str = schema_utils.StringOptions( + MISSING_VALUE_STRATEGY_OPTIONS, + default=DROP_ROW, + allow_none=False, + description="What strategy to follow when there's a missing value in an image column", + parameter_metadata=FEATURE_METADATA[IMAGE][PREPROCESSING]["missing_value_strategy"], + ) diff --git a/ludwig/schema/metadata/configs/decoders.yaml b/ludwig/schema/metadata/configs/decoders.yaml index 1b1992294d5..125767239e0 100644 --- a/ludwig/schema/metadata/configs/decoders.yaml +++ b/ludwig/schema/metadata/configs/decoders.yaml @@ -497,3 +497,44 @@ SequenceTaggerDecoder: vocab_size: ui_display_name: Not displayed internal_only: true +UNetDecoder: + type: + short_description: The UNet decoder convolutional and up-conv layers + long_description: + Stacks of two 2D convolutional layers with optional normalization + and relu activation, preceeded by an up-conv layer in all but the + final level of the decoder. 
+ compute_tier: 1 + conv_norm: + expected_impact: 2 + ui_display_name: Convolutional Normalization + height: + default_value_reasoning: + Computed internally, automatically, based on image + data preprocessing. + internal_only: true + ui_display_name: NOT DISPLAYED + input_size: + other_information: Internal Only + internal_only: true + related_parameters: + - "No" + ui_display_name: Not Displayed + num_channels: + default_value_reasoning: + Computed internally, automatically, based on image + data preprocessing. + internal_only: true + ui_display_name: NOT DISPLAYED + num_classes: + default_value_reasoning: + Computed internally, automatically, based on image + data preprocessing. + internal_only: true + ui_display_name: NOT DISPLAYED + width: + default_value_reasoning: + Computed internally, automatically, based on image + data preprocessing. + internal_only: true + ui_display_name: NOT DISPLAYED diff --git a/ludwig/schema/metadata/configs/encoders.yaml b/ludwig/schema/metadata/configs/encoders.yaml index e5853fe57d4..271f8f8322e 100644 --- a/ludwig/schema/metadata/configs/encoders.yaml +++ b/ludwig/schema/metadata/configs/encoders.yaml @@ -9572,3 +9572,31 @@ conv_params: filter_size: ui_display_name: null expected_impact: 2 +UNetEncoder: + type: + short_description: The UNet encoder convolutional and max pool layers + long_description: + Stacks of two 2D convolutional layers with optional normalization + and relu activation, followed by a max pool layer in all but the + final level of the encoder. + compute_tier: 1 + conv_norm: + expected_impact: 2 + ui_display_name: Convolutional Normalization + height: + default_value_reasoning: + Computed internally, automatically, based on image + data preprocessing. + internal_only: true + ui_display_name: NOT DISPLAYED + num_channels: + default_value_reasoning: + Computed internally, automatically, based on image + data preprocessing. 
+ ui_display_name: NOT DISPLAYED + width: + default_value_reasoning: + Computed internally, automatically, based on image + data preprocessing. + internal_only: true + ui_display_name: NOT DISPLAYED diff --git a/ludwig/schema/metadata/configs/features.yaml b/ludwig/schema/metadata/configs/features.yaml index bfe8bb27b5d..15007a4573f 100644 --- a/ludwig/schema/metadata/configs/features.yaml +++ b/ludwig/schema/metadata/configs/features.yaml @@ -437,6 +437,9 @@ image: infer_image_sample_size: ui_display_name: null expected_impact: 1 + infer_image_num_classes: + ui_display_name: null + expected_impact: 1 missing_value_strategy: default_value_reasoning: The default `fill_with_const` replaces missing @@ -455,6 +458,9 @@ image: num_channels: ui_display_name: null expected_impact: 2 + num_classes: + ui_display_name: null + expected_impact: 2 num_processes: ui_display_name: null expected_impact: 2 @@ -483,6 +489,12 @@ image: requires_equal_dimensions: ui_display_name: null expected_impact: 1 + dependencies: + expected_impact: 1 + reduce_dependencies: + expected_impact: 1 + reduce_input: + expected_impact: 1 number: preprocessing: computed_fill_value: diff --git a/ludwig/utils/image_utils.py b/ludwig/utils/image_utils.py index 253f0b49c6a..3470a9a9a6c 100644 --- a/ludwig/utils/image_utils.py +++ b/ludwig/utils/image_utils.py @@ -28,8 +28,7 @@ from torchvision.models._api import WeightsEnum from ludwig.api_annotations import DeveloperAPI -from ludwig.constants import CROP_OR_PAD, INTERPOLATE -from ludwig.encoders.base import Encoder +from ludwig.constants import CROP_OR_PAD, IMAGE_MAX_CLASSES, INTERPOLATE from ludwig.utils.fs_utils import get_bytes_obj_from_path from ludwig.utils.registry import Registry @@ -97,14 +96,6 @@ def is_image_score(path): return int(isinstance(path, str) and path.lower().endswith(IMAGE_EXTENSIONS)) -@DeveloperAPI -def is_torchvision_encoder(encoder_obj: Encoder) -> bool: - # TODO(travis): do this through an interface rather than conditional logic - 
def _binarize_image(img: torch.Tensor) -> torch.Tensor:
    """Snaps every value of an image to 0 or 255 by rounding `value / 255`, returning uint8."""
    img = img.type(torch.float32) / 255
    img = img.round() * 255
    return img.type(torch.uint8)


@DeveloperAPI
def get_unique_channels(
    image_sample: List[torch.Tensor],
    num_channels: int,
    num_classes: int = None,
) -> torch.Tensor:
    """Returns a tensor of unique channel values from a list of images.

    Args:
        image_sample: A list of images of dimensions [C x H x W] or [H x W], where C is the channel dimension
        num_channels: The expected number of channels
        num_classes: The expected number of classes or None

    Return:
        channel_class_map: A tensor mapping channel values to classes, where dim=0 is the class.

    Raises:
        ValueError: if an image has fewer than 2 dimensions, if the sample is empty, or if more than
            IMAGE_MAX_CLASSES distinct channel values are found.
    """
    n_images = 0
    no_new_class = 0
    channel_class_map = None
    for img in image_sample:
        if img.ndim < 2:
            raise ValueError(f"Invalid image dimensions {img.ndim}")
        if img.ndim == 2:
            img = img.unsqueeze(0)
        if num_channels == 1 and num_channels_in_image(img) != 1:
            img = grayscale(img)
        if num_classes == 2 and num_channels_in_image(img) == 1:
            img = _binarize_image(img)

        # [C x H x W] -> [H*W x C], so that every row is the channel value of one pixel
        img = img.flatten(1, 2)
        img = img.permute(1, 0)
        uniq_chans = img.unique(dim=0)

        if channel_class_map is None:
            channel_class_map = uniq_chans
        else:
            channel_class_map = torch.concat((channel_class_map, uniq_chans)).unique(dim=0)
        if channel_class_map.shape[0] > IMAGE_MAX_CLASSES:
            raise ValueError(
                f"Images inferred num classes {channel_class_map.shape[0]} exceeds " f"max classes {IMAGE_MAX_CLASSES}."
            )

        n_images += 1
        if n_images % 25 == 0:
            logger.info(f"Processed the first {n_images} images inferring {channel_class_map.shape[0]} classes...")

        # An image that contains as many distinct values as the whole map added no new class.
        if channel_class_map.shape[0] == uniq_chans.shape[0]:
            no_new_class += 1
            # Stop early after 4 such images in a row once the expected class count is reached.
            if no_new_class >= 4 and channel_class_map.shape[0] == num_classes:
                break  # early loop exit
        else:
            no_new_class = 0

    if channel_class_map is None:
        raise ValueError("Cannot infer image classes from an empty image sample.")

    logger.info(f"Inferred {channel_class_map.shape[0]} classes from the first {n_images} images.")
    return channel_class_map.type(torch.uint8)


@DeveloperAPI
def get_class_mask_from_image(
    channel_class_map: torch.Tensor,
    img: torch.Tensor,
) -> torch.Tensor:
    """Returns a masked image where each mask value is the channel class of the input.

    Args:
        channel_class_map: A tensor mapping channel values to classes, where dim=0 is the class.
        img: An input image of dimensions [C x H x W] or [H x W], where C is the channel dimension

    Return:
        [mask] A masked image of dimensions [H x W] where each value is the channel class of the input

    Raises:
        ValueError: if any pixel has a channel value that is not in `channel_class_map`.
    """
    num_classes = channel_class_map.shape[0]
    # Start every pixel at the out-of-range value num_classes so unmapped pixels can be detected.
    mask = torch.full((img.shape[-2], img.shape[-1]), num_classes, dtype=torch.uint8)
    if img.ndim == 2:
        img = img.unsqueeze(0)
    if num_classes == 2 and num_channels_in_image(img) == 1:
        img = _binarize_image(img)
    # [C x H x W] -> [H x W x C], so that a pixel can be compared against one map row
    img = img.permute(1, 2, 0)
    for nclass, value in enumerate(channel_class_map):
        mask[(img == value).all(-1)] = nclass

    if torch.any(mask.ge(num_classes)):
        raise ValueError(
            f"Image channel could not be mapped to a class because an unknown channel value was detected. "
            f"{num_classes} classes were inferred from the first set of images. This image has a channel "
            f"value that was not previously seen in the first set of images. Check preprocessing parameters "
            f"for image resizing, num channels, num classes and num samples. Image resizing may affect "
            f"channel values. "
        )

    return mask


@DeveloperAPI
def get_image_from_class_mask(
    channel_class_map: torch.Tensor,
    mask: np.ndarray,
) -> np.ndarray:
    """Returns an image with channel values determined from a corresponding mask.

    Args:
        channel_class_map: A tensor mapping channel values to classes, where dim=0 is the class.
        mask: A masked image of dimensions [H x W] where each value is the channel class of the final image

    Return:
        [img] An image of dimensions [C x H x W], where C is the channel dimension. Pixels whose mask
        value matches no class are left at 0.
    """
    mask = torch.from_numpy(mask)
    img = torch.zeros(channel_class_map.shape[1], mask.shape[-2], mask.shape[-1], dtype=torch.uint8)
    # [H x W x C] view that shares storage with img: writes through it fill img in place.
    pixels = img.permute(1, 2, 0)
    for nclass, value in enumerate(channel_class_map):
        pixels[mask == nclass] = value

    return img.numpy()
num_channels=num_channels) + inputs = torch.rand(batch_size, num_channels, height, width) + encoder_outputs = unet_encoder(inputs) + assert encoder_outputs[ENCODER_OUTPUT].shape[1:] == unet_encoder.output_shape + assert len(encoder_outputs[ENCODER_OUTPUT_STATE]) == 4 + + hidden = torch.reshape(encoder_outputs[ENCODER_OUTPUT], [batch_size, -1]) + + unet_decoder = UNetDecoder(hidden.size(dim=1), height, width, 1, num_classes) + combiner_outputs = { + HIDDEN: hidden, + ENCODER_OUTPUT_STATE: encoder_outputs[ENCODER_OUTPUT_STATE].copy(), # create a copy + } + + output = unet_decoder(combiner_outputs, target=None) + + assert list(output[LOGITS].size()) == [batch_size, num_classes, height, width] + + # check for parameter updating + target = torch.randn(output[LOGITS].shape) + combiner_outputs[ENCODER_OUTPUT_STATE] = encoder_outputs[ENCODER_OUTPUT_STATE] # restore state + fpc, tpc, upc, not_updated = check_module_parameters_updated(unet_decoder, (combiner_outputs, None), target) + assert upc == tpc, f"Failed to update parameters. Parameters not updated: {not_updated}" diff --git a/tests/ludwig/encoders/test_image_encoders.py b/tests/ludwig/encoders/test_image_encoders.py index 19e8afd1b88..c1f74208e3a 100644 --- a/tests/ludwig/encoders/test_image_encoders.py +++ b/tests/ludwig/encoders/test_image_encoders.py @@ -4,7 +4,7 @@ import torch from ludwig.constants import ENCODER_OUTPUT -from ludwig.encoders.image.base import MLPMixerEncoder, ResNetEncoder, Stacked2DCNN, ViTEncoder +from ludwig.encoders.image.base import MLPMixerEncoder, ResNetEncoder, Stacked2DCNN, UNetEncoder, ViTEncoder from ludwig.encoders.image.torchvision import ( TVAlexNetEncoder, TVConvNeXtEncoder, @@ -115,6 +115,23 @@ def test_vit_encoder(image_size: int, num_channels: int, use_pretrained: bool): assert tpc == upc, f"Not all expected parameters updated. Parameters not updated {not_updated}." 
@pytest.mark.parametrize("height,width,num_channels", [(224, 224, 1), (224, 224, 3)])
def test_unet_encoder(height: int, width: int, num_channels: int):
    """Checks the U-Net encoder's output shape and that a training step updates all its parameters."""
    # fixed seed so the random inputs and initial weights are repeatable
    set_random_seed(RANDOM_SEED)

    encoder = UNetEncoder(height=height, width=width, num_channels=num_channels)
    batch = torch.rand(2, num_channels, height, width)
    encoded = encoder(batch)[ENCODER_OUTPUT]
    assert encoded.shape[1:] == encoder.output_shape

    # parameter update check against a random target shaped like the encoder output
    target = torch.randn(encoded.shape)
    _, total_params, updated_params, not_updated = check_module_parameters_updated(encoder, (batch,), target)

    assert total_params == updated_params, (
        f"Not all expected parameters updated. Parameters not updated {not_updated}."
    )
2 DEVICE = get_torch_device() @@ -112,6 +122,76 @@ def test_image_input_feature(image_config: Dict, encoder: str, height: int, widt # ) +@pytest.mark.parametrize( + "encoder, decoder, height, width, num_channels, num_classes", + [ + ("unet", "unet", 128, 128, 3, 2), + ("unet", "unet", 32, 32, 3, 7), + ], +) +def test_image_output_feature( + encoder: str, + decoder: str, + height: int, + width: int, + num_channels: int, + num_classes: int, +) -> None: + # setup image input feature definition + input_feature_def = image_feature( + folder=".", + encoder={ + "type": encoder, + "height": height, + "width": width, + "num_channels": num_channels, + }, + ) + # create image input feature object + feature_cls = ImageInputFeature + schema_cls = ImageInputFeatureConfig + input_config = schema_cls.from_dict(input_feature_def) + input_feature_obj = feature_cls(input_config).to(DEVICE) + + # check one forward pass through input feature + input_tensor = torch.rand(size=(BATCH_SIZE, num_channels, height, width), dtype=torch.float32).to(DEVICE) + + encoder_output = input_feature_obj(input_tensor) + assert encoder_output[ENCODER_OUTPUT].shape == (BATCH_SIZE, *input_feature_obj.output_shape) + if encoder == "unet": + assert len(encoder_output[ENCODER_OUTPUT_STATE]) == 4 + + hidden = torch.reshape(encoder_output[ENCODER_OUTPUT], [BATCH_SIZE, -1]) + + # setup image output feature definition + output_feature_def = image_feature( + folder=".", + decoder={ + "type": decoder, + "height": height, + "width": width, + "num_channels": num_channels, + "num_classes": num_classes, + }, + input_size=hidden.size(dim=1), + ) + # create image output feature object + feature_cls = ImageOutputFeature + schema_cls = ImageOutputFeatureConfig + output_config = schema_cls.from_dict(output_feature_def) + output_feature_obj = feature_cls(output_config, {}).to(DEVICE) + + combiner_outputs = { + "combiner_output": hidden, + ENCODER_OUTPUT_STATE: encoder_output[ENCODER_OUTPUT_STATE], + } + + image_output = 
@pytest.mark.parametrize(["height", "width"], [(224, 224), (32, 32)])
def test_image_preproc_module_class_map(height, width):
    """Checks that preprocessing with a channel class map returns an [N x H x W] mask of classes 0..7."""
    # every combination of the values 40 and 41 over three channels, in lexicographic order
    channel_class_map = [[c0, c1, c2] for c0 in (40, 41) for c1 in (40, 41) for c2 in (40, 41)]
    preprocessing = {
        "num_processes": 1,
        "resize_method": CROP_OR_PAD,
        "infer_image_num_channels": True,
        "infer_image_dimensions": True,
        "infer_image_max_height": height,
        "infer_image_max_width": width,
        "infer_image_sample_size": 100,
        "infer_image_num_classes": True,
        "height": height,
        "width": width,
        "num_channels": 3,
        "num_classes": 8,
        "channel_class_map": channel_class_map,
    }
    module = _ImagePreprocessing({"preprocessing": preprocessing})

    # pixel values are drawn from {40, 41}, so every pixel matches one row of the class map
    images = torch.randint(40, 42, (2, 3, height, width))
    res = module(images)

    assert res.shape == torch.Size((2, height, width))
    assert torch.all(res.ge(0)) and torch.all(res.le(7))