forked from MIC-DKFZ/nnUNet
Commit
Showing 14 changed files with 21,342 additions and 171 deletions.
@@ -0,0 +1,235 @@
from typing import Union, Type, List, Tuple

import numpy as np
import torch
from torch import nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd

from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim, get_matching_convtransp
from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD
from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder
from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks
from dynamic_network_architectures.initialization.weight_init import InitWeights_He, init_last_bn_before_add_to_0

NUM_ORGANS = 10 + 1  # 10 organs + 1 background


class ResidualEncoderUNetOrgan(nn.Module):
    def __init__(self,
                 input_channels: int,
                 n_stages: int,
                 features_per_stage: Union[int, List[int], Tuple[int, ...]],
                 conv_op: Type[_ConvNd],
                 kernel_sizes: Union[int, List[int], Tuple[int, ...]],
                 strides: Union[int, List[int], Tuple[int, ...]],
                 n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]],
                 num_classes: int,
                 n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
                 conv_bias: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 deep_supervision: bool = False,
                 block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
                 bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
                 stem_channels: int = None
                 ):
        super().__init__()
        if isinstance(n_blocks_per_stage, int):
            n_blocks_per_stage = [n_blocks_per_stage] * n_stages
        if isinstance(n_conv_per_stage_decoder, int):
            n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
        assert len(n_blocks_per_stage) == n_stages, \
            "n_blocks_per_stage must have as many entries as we have resolution stages. " \
            f"here: {n_stages}. n_blocks_per_stage: {n_blocks_per_stage}"
        assert len(n_conv_per_stage_decoder) == (n_stages - 1), \
            "n_conv_per_stage_decoder must have one entry fewer than the number of resolution stages. " \
            f"here: {n_stages} stages, so it should have {n_stages - 1} entries. " \
            f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
        self.encoder = ResidualEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
                                       n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
                                       dropout_op_kwargs, nonlin, nonlin_kwargs, block, bottleneck_channels,
                                       return_skips=True, disable_default_stem=False, stem_channels=stem_channels)
        self.decoder = UNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision)

    def forward(self, x, organ=False):
        # organ=True additionally returns the auxiliary organ segmentation output(s)
        skips = self.encoder(x)
        return self.decoder(skips, organ)

    def compute_conv_feature_map_size(self, input_size):
        assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), \
            "just give the image size without color/feature channels or batch channel. " \
            "Do not give input_size=(b, c, x, y(, z)). Give input_size=(x, y(, z))!"
        return self.encoder.compute_conv_feature_map_size(input_size) + \
            self.decoder.compute_conv_feature_map_size(input_size)
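    # Illustrative call (not from the original code): for a 3D configuration trained on
    # 128x128x128 patches, net.compute_conv_feature_map_size((128, 128, 128)) returns the
    # total number of feature-map elements, which nnU-Net-style planners use as a memory proxy.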
    @staticmethod
    def initialize(module):
        InitWeights_He(1e-2)(module)
        init_last_bn_before_add_to_0(module)
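    # Usage sketch (assumption, mirroring nnU-Net conventions): after building the network,
    # apply the initializer to every submodule, e.g. net.apply(ResidualEncoderUNetOrgan.initialize).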
class UNetDecoder(nn.Module):
    def __init__(self,
                 encoder: Union[PlainConvEncoder, ResidualEncoder],
                 num_classes: int,
                 n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
                 deep_supervision,
                 nonlin_first: bool = False,
                 norm_op: Union[None, Type[nn.Module]] = None,
                 norm_op_kwargs: dict = None,
                 dropout_op: Union[None, Type[_DropoutNd]] = None,
                 dropout_op_kwargs: dict = None,
                 nonlin: Union[None, Type[torch.nn.Module]] = None,
                 nonlin_kwargs: dict = None,
                 conv_bias: bool = None
                 ):
        """
        This class needs the skips of the encoder as input in its forward.

        The encoder goes all the way to the bottleneck, so that's where the decoder picks up. Stages in the
        decoder are sorted by order of computation, so the first stage has the lowest resolution and takes the
        bottleneck features and the lowest skip as inputs.

        Each decoder stage has two (optionally three) parts:
        1) a transposed convolution that upsamples the feature maps of the stage below it (or the bottleneck in
           the case of the first stage)
        2) n_conv_per_stage conv blocks that merge the upsampled features with the concatenated skip
        3) a 1x1 segmentation output, used at every stage if deep_supervision=True, otherwise only at the final
           stage. Todo: enable upsample logits?
        :param encoder:
        :param num_classes:
        :param n_conv_per_stage:
        :param deep_supervision:
        """
        super().__init__()
        self.deep_supervision = deep_supervision
        self.encoder = encoder
        self.num_classes = num_classes
        n_stages_encoder = len(encoder.output_channels)
        if isinstance(n_conv_per_stage, int):
            n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
        assert len(n_conv_per_stage) == n_stages_encoder - 1, \
            "n_conv_per_stage must have as many entries as we have resolution stages - 1 " \
            "(n_stages in encoder - 1), here: %d" % n_stages_encoder
        transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)
        # fall back to the encoder's settings for anything that wasn't explicitly specified
        conv_bias = encoder.conv_bias if conv_bias is None else conv_bias
        norm_op = encoder.norm_op if norm_op is None else norm_op
        norm_op_kwargs = encoder.norm_op_kwargs if norm_op_kwargs is None else norm_op_kwargs
        dropout_op = encoder.dropout_op if dropout_op is None else dropout_op
        dropout_op_kwargs = encoder.dropout_op_kwargs if dropout_op_kwargs is None else dropout_op_kwargs
        nonlin = encoder.nonlin if nonlin is None else nonlin
        nonlin_kwargs = encoder.nonlin_kwargs if nonlin_kwargs is None else nonlin_kwargs

        # we start with the bottleneck and work our way up
        stages = []
        transpconvs = []
        seg_layers = []
        organ_seg_layers = []
        for s in range(1, n_stages_encoder):
            input_features_below = encoder.output_channels[-s]
            input_features_skip = encoder.output_channels[-(s + 1)]
            stride_for_transpconv = encoder.strides[-s]
            transpconvs.append(transpconv_op(
                input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
                bias=conv_bias
            ))
            # input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output)
            stages.append(StackedConvBlocks(
                n_conv_per_stage[s - 1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
                encoder.kernel_sizes[-(s + 1)], 1,
                conv_bias,
                norm_op,
                norm_op_kwargs,
                dropout_op,
                dropout_op_kwargs,
                nonlin,
                nonlin_kwargs,
                nonlin_first
            ))

            # we always build the deep supervision outputs so that we can always load parameters. If we didn't,
            # a model trained with deep_supervision=True could not easily be loaded at inference time where
            # deep supervision is not needed. It's just a convenience thing
            seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))
            # auxiliary head: NUM_ORGANS (11) classes for organ segmentation
            organ_seg_layers.append(encoder.conv_op(input_features_skip, NUM_ORGANS, 1, 1, 0, bias=True))

        self.stages = nn.ModuleList(stages)
        self.transpconvs = nn.ModuleList(transpconvs)
        self.seg_layers = nn.ModuleList(seg_layers)
        self.organ_seg_layers = nn.ModuleList(organ_seg_layers)
    def forward(self, skips, organ=False):
        """
        We expect the skips in the order they were computed, i.e. the bottleneck is the last entry.
        :param skips:
        :return:
        """
        lres_input = skips[-1]
        seg_outputs = []
        organ_seg_outputs = []
        for s in range(len(self.stages)):
            x = self.transpconvs[s](lres_input)
            x = torch.cat((x, skips[-(s + 2)]), 1)
            x = self.stages[s](x)
            if self.deep_supervision:
                seg_outputs.append(self.seg_layers[s](x))
                if organ:
                    organ_seg_outputs.append(self.organ_seg_layers[s](x))
            elif s == (len(self.stages) - 1):
                # without deep supervision we only predict at the final (full) resolution
                seg_outputs.append(self.seg_layers[-1](x))
                if organ:
                    organ_seg_outputs.append(self.organ_seg_layers[-1](x))
            lres_input = x

        # invert outputs so that the largest segmentation prediction is returned first
        seg_outputs = seg_outputs[::-1]
        organ_seg_outputs = organ_seg_outputs[::-1]

        if not self.deep_supervision:
            r = seg_outputs[0]
            if organ:
                o = organ_seg_outputs[0]
        else:
            r = seg_outputs
            if organ:
                o = organ_seg_outputs
        if organ:
            return r, o
        return r
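    # Return contract (derived from the code above): with deep_supervision=True and organ=True,
    # forward returns ([seg_full, seg_half, ...], [organ_full, organ_half, ...]); with
    # deep_supervision=False and organ=True, it returns (seg_full, organ_full); with organ=False,
    # only the main segmentation output(s).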
    def compute_conv_feature_map_size(self, input_size):
        """
        IMPORTANT: input_size is the input_size of the encoder!
        :param input_size:
        :return:
        """
        # first we need to compute the skip sizes. Skip the bottleneck because all output feature maps of our
        # ops will at least have the size of the skip above it (therefore -1)
        skip_sizes = []
        for s in range(len(self.encoder.strides) - 1):
            skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
            input_size = skip_sizes[-1]

        assert len(skip_sizes) == len(self.stages)

        # our ops are the other way around, so let's match things up
        output = np.int64(0)
        for s in range(len(self.stages)):
            # conv blocks
            output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s + 1)])
            # trans conv
            output += np.prod([self.encoder.output_channels[-(s + 2)], *skip_sizes[-(s + 1)]], dtype=np.int64)
            # segmentation (note: the auxiliary organ heads are not included in this estimate)
            if self.deep_supervision or (s == (len(self.stages) - 1)):
                output += np.prod([self.num_classes, *skip_sizes[-(s + 1)]], dtype=np.int64)
        return output
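

# ---------------------------------------------------------------------------
# Minimal smoke test (added sketch, not part of the original commit). The 2D
# configuration below is illustrative only; real configurations come from the
# nnU-Net experiment planner.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    net = ResidualEncoderUNetOrgan(
        input_channels=1,
        n_stages=4,
        features_per_stage=(32, 64, 128, 256),
        conv_op=nn.Conv2d,
        kernel_sizes=3,
        strides=(1, 2, 2, 2),
        n_blocks_per_stage=(1, 2, 2, 2),
        num_classes=2,
        n_conv_per_stage_decoder=(2, 2, 2),
        conv_bias=True,
        norm_op=nn.InstanceNorm2d,
        norm_op_kwargs={'eps': 1e-5, 'affine': True},
        nonlin=nn.LeakyReLU,
        nonlin_kwargs={'inplace': True},
        deep_supervision=True,
    )
    net.apply(net.initialize)

    x = torch.rand(2, 1, 64, 64)
    seg, organ_seg = net(x, organ=True)
    # deep supervision: one prediction per decoder stage, full resolution first
    print([tuple(o.shape) for o in seg])        # [(2, 2, 64, 64), (2, 2, 32, 32), (2, 2, 16, 16)]
    print([tuple(o.shape) for o in organ_seg])  # channel count equals NUM_ORGANS == 11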