Skip to content

Commit

Permalink
upd for autopet3
Browse files Browse the repository at this point in the history
  • Loading branch information
fitzjalen committed Oct 26, 2024
1 parent d145bd4 commit 6ba87c3
Show file tree
Hide file tree
Showing 14 changed files with 21,342 additions and 171 deletions.
15 changes: 15 additions & 0 deletions configs/mosaic_planner.json
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,20 @@
"inherits_from": "3d_fullres_no_resampling",
"spacing": [3.0, 3.0, 3.0],
"data_identifier": "nnUNetPlans_3d_fullres_lowres_NoRsmp"
},
"3d_fullres_highres": {
"inherits_from": "3d_fullres",
"spacing": [1.0, 1.0, 1.0],
"data_identifier": "nnUNetPlans_3d_fullres_highres"
},
"3d_fullres_stdres": {
"inherits_from": "3d_fullres",
"spacing": [1.5, 1.5, 1.5],
"data_identifier": "nnUNetPlans_3d_fullres_stdres"
},
"3d_fullres_lowres": {
"inherits_from": "3d_fullres",
"spacing": [3.0, 3.0, 3.0],
"data_identifier": "nnUNetPlans_3d_fullres_lowres"
}
}
235 changes: 235 additions & 0 deletions nnunetv2/architecture/ResidualEncoderUNetOrgan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
from typing import Union, Type, List, Tuple

import torch
from dynamic_network_architectures.building_blocks.helper import convert_conv_op_to_dim
from dynamic_network_architectures.building_blocks.plain_conv_encoder import PlainConvEncoder
from dynamic_network_architectures.building_blocks.residual import BasicBlockD, BottleneckD
from dynamic_network_architectures.building_blocks.residual_encoders import ResidualEncoder
from dynamic_network_architectures.building_blocks.unet_residual_decoder import UNetResDecoder
from dynamic_network_architectures.initialization.weight_init import InitWeights_He
from dynamic_network_architectures.initialization.weight_init import init_last_bn_before_add_to_0
from torch import nn
from torch.nn.modules.conv import _ConvNd
from torch.nn.modules.dropout import _DropoutNd
import numpy as np

NUM_ORGANS = 10 + 1 # 10 organs + 1 background

class ResidualEncoderUNetOrgan(nn.Module):
def __init__(self,
input_channels: int,
n_stages: int,
features_per_stage: Union[int, List[int], Tuple[int, ...]],
conv_op: Type[_ConvNd],
kernel_sizes: Union[int, List[int], Tuple[int, ...]],
strides: Union[int, List[int], Tuple[int, ...]],
n_blocks_per_stage: Union[int, List[int], Tuple[int, ...]],
num_classes: int,
n_conv_per_stage_decoder: Union[int, Tuple[int, ...], List[int]],
conv_bias: bool = False,
norm_op: Union[None, Type[nn.Module]] = None,
norm_op_kwargs: dict = None,
dropout_op: Union[None, Type[_DropoutNd]] = None,
dropout_op_kwargs: dict = None,
nonlin: Union[None, Type[torch.nn.Module]] = None,
nonlin_kwargs: dict = None,
deep_supervision: bool = False,
block: Union[Type[BasicBlockD], Type[BottleneckD]] = BasicBlockD,
bottleneck_channels: Union[int, List[int], Tuple[int, ...]] = None,
stem_channels: int = None
):
super().__init__()
if isinstance(n_blocks_per_stage, int):
n_blocks_per_stage = [n_blocks_per_stage] * n_stages
if isinstance(n_conv_per_stage_decoder, int):
n_conv_per_stage_decoder = [n_conv_per_stage_decoder] * (n_stages - 1)
assert len(n_blocks_per_stage) == n_stages, "n_blocks_per_stage must have as many entries as we have " \
f"resolution stages. here: {n_stages}. " \
f"n_blocks_per_stage: {n_blocks_per_stage}"
assert len(n_conv_per_stage_decoder) == (n_stages - 1), "n_conv_per_stage_decoder must have one less entries " \
f"as we have resolution stages. here: {n_stages} " \
f"stages, so it should have {n_stages - 1} entries. " \
f"n_conv_per_stage_decoder: {n_conv_per_stage_decoder}"
self.encoder = ResidualEncoder(input_channels, n_stages, features_per_stage, conv_op, kernel_sizes, strides,
n_blocks_per_stage, conv_bias, norm_op, norm_op_kwargs, dropout_op,
dropout_op_kwargs, nonlin, nonlin_kwargs, block, bottleneck_channels,
return_skips=True, disable_default_stem=False, stem_channels=stem_channels)
self.decoder = UNetDecoder(self.encoder, num_classes, n_conv_per_stage_decoder, deep_supervision)

def forward(self, x, organ=False):
skips = self.encoder(x)
return self.decoder(skips, organ)

def compute_conv_feature_map_size(self, input_size):
assert len(input_size) == convert_conv_op_to_dim(self.encoder.conv_op), "just give the image size without color/feature channels or " \
"batch channel. Do not give input_size=(b, c, x, y(, z)). " \
"Give input_size=(x, y(, z))!"
return self.encoder.compute_conv_feature_map_size(input_size) + self.decoder.compute_conv_feature_map_size(input_size)

@staticmethod
def initialize(module):
InitWeights_He(1e-2)(module)
init_last_bn_before_add_to_0(module)


from dynamic_network_architectures.building_blocks.simple_conv_blocks import StackedConvBlocks
from dynamic_network_architectures.building_blocks.helper import get_matching_convtransp


class UNetDecoder(nn.Module):
def __init__(self,
encoder: Union[PlainConvEncoder, ResidualEncoder],
num_classes: int,
n_conv_per_stage: Union[int, Tuple[int, ...], List[int]],
deep_supervision,
nonlin_first: bool = False,
norm_op: Union[None, Type[nn.Module]] = None,
norm_op_kwargs: dict = None,
dropout_op: Union[None, Type[_DropoutNd]] = None,
dropout_op_kwargs: dict = None,
nonlin: Union[None, Type[torch.nn.Module]] = None,
nonlin_kwargs: dict = None,
conv_bias: bool = None
):
"""
This class needs the skips of the encoder as input in its forward.
the encoder goes all the way to the bottleneck, so that's where the decoder picks up. stages in the decoder
are sorted by order of computation, so the first stage has the lowest resolution and takes the bottleneck
features and the lowest skip as inputs
the decoder has two (three) parts in each stage:
1) conv transpose to upsample the feature maps of the stage below it (or the bottleneck in case of the first stage)
2) n_conv_per_stage conv blocks to let the two inputs get to know each other and merge
3) (optional if deep_supervision=True) a segmentation output Todo: enable upsample logits?
:param encoder:
:param num_classes:
:param n_conv_per_stage:
:param deep_supervision:
"""
super().__init__()
self.deep_supervision = deep_supervision
self.encoder = encoder
self.num_classes = num_classes
n_stages_encoder = len(encoder.output_channels)
if isinstance(n_conv_per_stage, int):
n_conv_per_stage = [n_conv_per_stage] * (n_stages_encoder - 1)
assert len(n_conv_per_stage) == n_stages_encoder - 1, "n_conv_per_stage must have as many entries as we have " \
"resolution stages - 1 (n_stages in encoder - 1), " \
"here: %d" % n_stages_encoder

transpconv_op = get_matching_convtransp(conv_op=encoder.conv_op)
conv_bias = encoder.conv_bias if conv_bias is None else conv_bias
norm_op = encoder.norm_op if norm_op is None else norm_op
norm_op_kwargs = encoder.norm_op_kwargs if norm_op_kwargs is None else norm_op_kwargs
dropout_op = encoder.dropout_op if dropout_op is None else dropout_op
dropout_op_kwargs = encoder.dropout_op_kwargs if dropout_op_kwargs is None else dropout_op_kwargs
nonlin = encoder.nonlin if nonlin is None else nonlin
nonlin_kwargs = encoder.nonlin_kwargs if nonlin_kwargs is None else nonlin_kwargs


# we start with the bottleneck and work out way up
stages = []
transpconvs = []
seg_layers = []
organ_seg_layers = []
for s in range(1, n_stages_encoder):
input_features_below = encoder.output_channels[-s]
input_features_skip = encoder.output_channels[-(s + 1)]
stride_for_transpconv = encoder.strides[-s]
transpconvs.append(transpconv_op(
input_features_below, input_features_skip, stride_for_transpconv, stride_for_transpconv,
bias=conv_bias
))
# input features to conv is 2x input_features_skip (concat input_features_skip with transpconv output)
stages.append(StackedConvBlocks(
n_conv_per_stage[s-1], encoder.conv_op, 2 * input_features_skip, input_features_skip,
encoder.kernel_sizes[-(s + 1)], 1,
conv_bias,
norm_op,
norm_op_kwargs,
dropout_op,
dropout_op_kwargs,
nonlin,
nonlin_kwargs,
nonlin_first
))

# we always build the deep supervision outputs so that we can always load parameters. If we don't do this
# then a model trained with deep_supervision=True could not easily be loaded at inference time where
# deep supervision is not needed. It's just a convenience thing
seg_layers.append(encoder.conv_op(input_features_skip, num_classes, 1, 1, 0, bias=True))
organ_seg_layers.append(encoder.conv_op(input_features_skip, NUM_ORGANS, 1, 1, 0, bias=True)) # 11 classes for organ segmentation

self.stages = nn.ModuleList(stages)
self.transpconvs = nn.ModuleList(transpconvs)
self.seg_layers = nn.ModuleList(seg_layers)
self.organ_seg_layers = nn.ModuleList(organ_seg_layers)

def forward(self, skips, organ=False):
"""
we expect to get the skips in the order they were computed, so the bottleneck should be the last entry
:param skips:
:return:
"""
lres_input = skips[-1]
seg_outputs = []
organ_seg_outputs = []
for s in range(len(self.stages)):
x = self.transpconvs[s](lres_input)
x = torch.cat((x, skips[-(s+2)]), 1)
x = self.stages[s](x)
if self.deep_supervision:
seg_outputs.append(self.seg_layers[s](x))
if organ:
organ_seg_outputs.append(self.organ_seg_layers[s](x))
elif s == (len(self.stages) - 1):
seg_outputs.append(self.seg_layers[-1](x))
if organ:
organ_seg_outputs.append(self.organ_seg_layers[-1](x))
lres_input = x

# invert seg outputs so that the largest segmentation prediction is returned first
seg_outputs = seg_outputs[::-1]
organ_seg_outputs = organ_seg_outputs[::-1]

if not self.deep_supervision:
r = seg_outputs[0]
if organ:
o = organ_seg_outputs[0]
else:
r = seg_outputs
if organ:
o = organ_seg_outputs
if organ:
return r, o
else:
return r

def compute_conv_feature_map_size(self, input_size):
"""
IMPORTANT: input_size is the input_size of the encoder!
:param input_size:
:return:
"""
# first we need to compute the skip sizes. Skip bottleneck because all output feature maps of our ops will at
# least have the size of the skip above that (therefore -1)
skip_sizes = []
for s in range(len(self.encoder.strides) - 1):
skip_sizes.append([i // j for i, j in zip(input_size, self.encoder.strides[s])])
input_size = skip_sizes[-1]
# print(skip_sizes)

assert len(skip_sizes) == len(self.stages)

# our ops are the other way around, so let's match things up
output = np.int64(0)
for s in range(len(self.stages)):
# print(skip_sizes[-(s+1)], self.encoder.output_channels[-(s+2)])
# conv blocks
output += self.stages[s].compute_conv_feature_map_size(skip_sizes[-(s+1)])
# trans conv
output += np.prod([self.encoder.output_channels[-(s+2)], *skip_sizes[-(s+1)]], dtype=np.int64)
# segmentation
if self.deep_supervision or (s == (len(self.stages) - 1)):
output += np.prod([self.num_classes, *skip_sizes[-(s+1)]], dtype=np.int64)
return output
Empty file.
Loading

0 comments on commit 6ba87c3

Please sign in to comment.