Merge branch 'MIC-DKFZ:master' into master
fitzjalen authored May 29, 2024
2 parents d783216 + d12a0c1 commit 62fe253
Showing 3 changed files with 25 additions and 6 deletions.
10 changes: 8 additions & 2 deletions nnunetv2/training/loss/compound_losses.py
@@ -83,14 +83,20 @@ def __init__(self, bce_kwargs, soft_dice_kwargs, weight_ce=1, weight_dice=1, use
     def forward(self, net_output: torch.Tensor, target: torch.Tensor):
         if self.use_ignore_label:
             # target is one hot encoded here. invert it so that it is True wherever we can compute the loss
-            mask = (1 - target[:, -1:]).bool()
+            if target.dtype == torch.bool:
+                mask = ~target[:, -1:]
+            else:
+                mask = (1 - target[:, -1:]).bool()
             # remove ignore channel now that we have the mask
-            target_regions = torch.clone(target[:, :-1])
+            # why did we use clone in the past? Should have documented that...
+            # target_regions = torch.clone(target[:, :-1])
+            target_regions = target[:, :-1]
         else:
             target_regions = target
             mask = None

         dc_loss = self.dc(net_output, target_regions, loss_mask=mask)
+        target_regions = target_regions.float()
         if mask is not None:
             ce_loss = (self.ce(net_output, target_regions) * mask).sum() / torch.clip(mask.sum(), min=1e-8)
         else:
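The dtype dispatch above exists because PyTorch does not support the minus operator on bool tensors (it raises a RuntimeError suggesting ~ or logical_not() instead), so a bool-encoded one-hot target has to be inverted with ~. A minimal standalone sketch of that mask logic; the helper name build_loss_mask is illustrative, not from the repository:

import torch

def build_loss_mask(target: torch.Tensor) -> torch.Tensor:
    # The last channel marks ignored voxels; invert it so the mask is True
    # wherever the loss may be computed.
    if target.dtype == torch.bool:
        return ~target[:, -1:]          # bitwise-not is valid on bool tensors
    return (1 - target[:, -1:]).bool()  # numeric path for float/int targets

# toy target: batch 1, 3 region channels plus 1 ignore channel, 2x2 voxels
t = torch.zeros(1, 4, 2, 2, dtype=torch.bool)
t[:, -1, 0, 0] = True  # mark one voxel as ignored
assert build_loss_mask(t).eq(build_loss_mask(t.float())).all()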
6 changes: 3 additions & 3 deletions nnunetv2/training/loss/dice.py
@@ -142,13 +142,13 @@ def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
             # if this is the case then gt is probably already a one hot encoding
             y_onehot = gt
         else:
-            y_onehot = torch.zeros(net_output.shape, device=net_output.device)
+            y_onehot = torch.zeros(net_output.shape, device=net_output.device, dtype=torch.bool)
             y_onehot.scatter_(1, gt.long(), 1)

     tp = net_output * y_onehot
-    fp = net_output * (1 - y_onehot)
+    fp = net_output * (~y_onehot)
     fn = (1 - net_output) * y_onehot
-    tn = (1 - net_output) * (1 - y_onehot)
+    tn = (1 - net_output) * (~y_onehot)

     if mask is not None:
         with torch.no_grad():
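Building the one-hot as bool cuts its memory footprint to one byte per element (versus four for the default float32) on what is often the largest intermediate tensor in the Dice computation. Multiplying a float tensor by a bool tensor promotes back to float, so tp/fp/fn/tn are numerically unchanged; only the complements must switch from 1 - y_onehot to ~y_onehot. A small self-contained sketch of the changed path, with toy shapes assumed for illustration:

import torch

# assumed shapes: net_output holds softmax probabilities, gt is an integer
# label map with a channel dimension of 1
net_output = torch.softmax(torch.randn(2, 3, 4, 4), dim=1)
gt = torch.randint(0, 3, (2, 1, 4, 4))

y_onehot = torch.zeros(net_output.shape, device=net_output.device, dtype=torch.bool)
y_onehot.scatter_(1, gt.long(), 1)   # the scalar 1 is cast to True

tp = net_output * y_onehot           # float * bool promotes to float
fp = net_output * (~y_onehot)        # 1 - y_onehot would raise on bool
fn = (1 - net_output) * y_onehot
tn = (1 - net_output) * (~y_onehot)

# sanity check: per voxel the four terms always sum to 1
assert torch.allclose(tp + fp + fn + tn, torch.ones_like(net_output))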
15 changes: 14 additions & 1 deletion nnunetv2/training/nnUNetTrainer/nnUNetTrainer.py
@@ -241,6 +241,16 @@ def initialize(self):

     def _do_i_compile(self):
         # new default: compile is enabled!
+
+        # CPU compile crashes for 2D models. Not sure if we even want to support CPU compile!? Better disable
+        if self.device == torch.device('cpu'):
+            return False
+
+        # default torch.compile doesn't work on windows because there are apparently no triton wheels for it
+        # https://discuss.pytorch.org/t/windows-support-timeline-for-torch-compile/182268/2
+        if os.name == 'nt':
+            return False
+
         if 'nnUNet_compile' not in os.environ.keys():
             return True
         else:
@@ -1057,7 +1067,10 @@ def validation_step(self, batch: dict) -> dict:
                 # CAREFUL that you don't rely on target after this line!
                 target[target == self.label_manager.ignore_label] = 0
             else:
-                mask = 1 - target[:, -1:]
+                if target.dtype == torch.bool:
+                    mask = ~target[:, -1:]
+                else:
+                    mask = 1 - target[:, -1:]
                 # CAREFUL that you don't rely on target after this line!
                 target = target[:, :-1]
         else:
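The new gate short-circuits on CPU and Windows before consulting the nnUNet_compile environment variable, and the validation_step change mirrors the same bool-aware mask dispatch as in compound_losses.py above. A standalone sketch of the compile gate's control flow; the final return is an assumption, since the else branch is truncated in the hunk:

import os
import torch

def do_i_compile_sketch(device: torch.device) -> bool:
    # CPU compile crashes for 2D models, so it is disabled outright.
    if device == torch.device('cpu'):
        return False
    # torch.compile depends on triton, which has no Windows wheels.
    if os.name == 'nt':
        return False
    # Compile by default when the user has not set nnUNet_compile.
    if 'nnUNet_compile' not in os.environ:
        return True
    # ASSUMPTION: the diff truncates the else branch; parsing the variable
    # as a truthy string is an illustration, not taken from the commit.
    return os.environ['nnUNet_compile'].lower() in ('true', '1', 't')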
