Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated main file #4

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@ output_patch_temp/
output_test_*/
weights/
fiji/
*.csv
*.csv
data/
__pycache__
output_patch_jpg/
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ python main.py --pilot=0
Installation option 2:
first to install a PyImageJ environment with all packages satisfied
conda install mamba -n base -c conda-forge
mamba create -n pyimagej -c conda-forge pyimagej openjdk=8
mamba create -n pyimagej -c conda-forge pyimagej openjdk=11
conda activate pyimagej
Then, inside this environment, install dependencies for [PyTorch](https://pytorch.org/) and a bunch of other image processing packages such as sciki-image.

Expand Down
26 changes: 26 additions & 0 deletions convert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
from PIL import Image

yourpath = os.getcwd() + '/output_patch_temp/'
outpath = os.getcwd() + '/output_patch_jpg/'

def convert_to_jpeg(img_name):
yourpath = os.getcwd() + '/output_patch_temp/'
outpath = os.getcwd() + '/output_patch_jpg/' + img_name

for root, dirs, files in os.walk(yourpath, topdown=False):
for name in files:
#print(os.path.join(outpath, name))
if os.path.splitext(os.path.join(root, name))[1].lower() == ".tiff":
if os.path.isfile(os.path.splitext(os.path.join(outpath, name))[0] + ".jpg"):
print(f"A jpeg file already exists for {name}")
# If a jpeg is *NOT* present, create one from the tiff.
else:
outfile = os.path.splitext(os.path.join(outpath, name))[0] + ".jpg"
try:
im = Image.open(os.path.join(root, name))
print("Generating jpeg for %s" % name)
im.thumbnail(im.size)
im.save(outfile, "JPEG", quality=100)
except(Exception, e):
print(e)
7 changes: 3 additions & 4 deletions env.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ channels:
- conda-forge
- defaults
dependencies:
- python=3.7.6
- cudatoolkit=10.2
- cudatoolkit
- numpy
- pillow
- pip
- pytorch=1.8.0
- pytorch
- torchvision
- matplotlib
- pyyaml
Expand All @@ -19,7 +18,7 @@ dependencies:
- tqdm
- imglyb
- jpype1
- openjdk=8
- openjdk=11
- pyimagej
- scyjava
- xarray
Expand Down
31 changes: 27 additions & 4 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from skimage import io, img_as_ubyte, morphology, img_as_bool, img_as_float, exposure, color
from skimage.util.shape import view_as_windows
from skimage.util import crop, pad
from skimage.util import crop
from skimage.transform import resize, rescale
from PIL import Image
import imagej


import numpy as np
import os, glob, sys
from collections import OrderedDict
Expand All @@ -24,6 +23,28 @@

import model_SHG as md

def convert_to_jpeg(img_name):
yourpath = os.getcwd() + '/output_patch_temp/'
outpath = os.getcwd() + '/output_patch_jpg/' + img_name
os.makedirs(f'output_patch_jpg/{img_name}/')
print("Generating jpeg for %s" % img_name)

for root, dirs, files in os.walk(yourpath, topdown=False):
for name in files:
#print(os.path.join(outpath, name))
if os.path.splitext(os.path.join(root, name))[1].lower() == ".tiff":
if os.path.isfile(os.path.splitext(os.path.join(outpath, name))[0] + ".jpg"):
print(f"A jpeg file already exists for {name}")
# If a jpeg is *NOT* present, create one from the tiff.
else:
outfile = os.path.splitext(os.path.join(outpath, name))[0] + ".jpg"
try:
im = Image.open(os.path.join(root, name))
im.thumbnail(im.size)
im.save(outfile, "JPEG", quality=100)
except(Exception, e):
print(e)

def generate_csv(img_dir, csv_dir):
file_list= [name for name in os.listdir(img_dir) if
os.path.isfile(os.path.join(img_dir, name))]
Expand Down Expand Up @@ -104,7 +125,8 @@ def demo(args):
model.to(device)

print('loading ImageJ, please wait')
ij = imagej.init('fiji/Fiji.app/')
# ij = imagej.init('fiji/Fiji.app')
ij = imagej.init('sc.fiji:fiji:2.1.1')

# use for SHG
TASK = args.input_folder
Expand Down Expand Up @@ -138,7 +160,7 @@ def demo(args):
canvas_1 = int(window_shape[1] * shape_1_factor)
pad_0 = canvas_0 - img.shape[0]
pad_1 = canvas_1 - img.shape[1]
canvas = pad(img, ((0, pad_0), (0, pad_1), (0, 0)), mode='reflect')
canvas = np.pad(img, ((0, pad_0), (0, pad_1), (0, 0)), mode='reflect')
windows = view_as_windows(canvas, window_shape, step_size)
with open(OUTPUT_PATCH_DIR+'TileConfiguration.txt', 'w') as text_file:
print('dim = {}'.format(2), file=text_file)
Expand Down Expand Up @@ -184,6 +206,7 @@ def demo(args):
c1 = exposure.rescale_intensity(c1, in_range=(0, 255), out_range=(0, 1))
print(str(k+1)+"/" + str(len(files)) + " output saved as: " + output_name)
io.imsave(output_name, img_as_ubyte(c1))
convert_to_jpeg(fn)
if args.pilot:
break

Expand Down
6 changes: 6 additions & 0 deletions model_SHG.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
import torch.nn.init as init
import torch.nn.functional as F
import math
from torch.utils.data import Dataset, DataLoader
from torchvision import utils
import torch.functional as F
import torch

import model_SHG as md

class SkipBlock(nn.Module):
def __init__(self, in_features, out_features):
Expand Down
17 changes: 10 additions & 7 deletions packages.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
matplotlib==3.1.2
numpy==1.17.4
pandas==0.25.3
Pillow==5.3.0
scikit-image==0.16.2
tqdm==4.42.0
pytorch==1.3.1
matplotlib
numpy
pandas
Pillow
scikit-image
tqdm
torch
imagej
torchvision
pyimagej
127 changes: 127 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import math
from skimage import io, img_as_ubyte, morphology, img_as_bool, img_as_float, exposure, color
from skimage.util.shape import view_as_windows
from skimage.util import crop
from skimage.transform import resize, rescale
from PIL import Image
import imagej

import numpy as np
import os, glob, sys
from collections import OrderedDict
import shutil
import argparse
import pandas as pd
import csv
import warnings
warnings.simplefilter("ignore", UserWarning)


from torch.utils.data import Dataset, DataLoader
from torchvision import utils
import torch.functional as F
import torch

import model_SHG as md

"""Fazer Augmentation e arrumar funções de Loss"""
model = md.GeneratorUNet(3,1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)


num_epochs = 250
best_loss = float('inf')
best_loss_ce = float('inf')
best_loss_dice = float('inf')
best_iou_score = float('inf')
best_accuracy = float('inf')

patience = 25
trigger_times = 0
accumulation_steps = 4


for epoch in range(num_epochs):
model.train()
train_loss, train_ce_loss, train_dice_loss = 0.0, 0.0, 0.0
train_iou_score, train_accuracy = 0.0, 0.0
# contador de iteracao
i_loop = 0
for images, masks in train_loader:
images, masks = images.to(device), masks.to(device)

outputs = model(images)
# outputs = F.interpolate(outputs, size=(256, 256), mode='bilinear', align_corners=False)
loss, ce_loss, dice_loss = criterion(outputs, masks)
loss.backward()

iou_score = m_iou(outputs, masks)
accuracy = pixel_accuracy(outputs, masks)

if (i_loop + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()

i_loop +=1
train_loss += loss.item() * images.size(0)
train_ce_loss += ce_loss.item() * images.size(0)
train_dice_loss += dice_loss.item() * images.size(0)
train_iou_score += iou_score * images.size(0)
train_accuracy += accuracy * images.size(0)

train_loss /= len(train_loader.dataset)
train_ce_loss /= len(train_loader.dataset)
train_dice_loss /= len(train_loader.dataset)
train_iou_score /= len(train_loader.dataset)
train_accuracy /= len(train_loader.dataset)

model.eval()
valid_loss, valid_ce_loss, valid_dice_loss = 0.0, 0.0, 0.0
valid_iou_score, valid_accuracy = 0.0, 0.0
with torch.no_grad():
for images, masks in valid_loader:
images, masks = images.to(device), masks.to(device)
outputs = model(images)
# outputs = F.interpolate(outputs, size=(256, 256), mode='bilinear', align_corners=False)

loss, ce_loss, dice_loss = criterion(outputs, masks)
iou_score = m_iou(outputs, masks)
accuracy = pixel_accuracy(outputs, masks)


valid_loss += loss.item() * images.size(0)
valid_ce_loss += ce_loss.item() * images.size(0)
valid_dice_loss += dice_loss.item() * images.size(0)
valid_iou_score += iou_score * images.size(0)
valid_accuracy += accuracy * images.size(0)

valid_loss /= len(valid_loader.dataset)
valid_ce_loss /= len(valid_loader.dataset)
valid_dice_loss /= len(valid_loader.dataset)
valid_iou_score /= len(valid_loader.dataset)
valid_accuracy /= len(valid_loader.dataset)

print(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.3f} (CE: {train_ce_loss:.3f}, Dice: {train_dice_loss:.3f}, IoU: {train_iou_score:.3f}, Pixel Accuracy: {train_accuracy:.3f}). Valid Loss: {valid_loss:.3f} (CE: {valid_ce_loss:.3f}, Dice: {valid_dice_loss:.3f}, IoU: {valid_iou_score:.3f}, Pixel Accuracy: {valid_accuracy:.3f})")

# Early stopping
if valid_loss < best_loss:
best_loss = valid_loss
best_loss_ce = valid_ce_loss
best_loss_dice = valid_dice_loss
best_iou_score = valid_iou_score
best_accuracy = valid_accuracy
torch.save(model.state_dict(), 'best_model.pth')
trigger_times = 0
else:
trigger_times += 1
if trigger_times >= patience:
print('Early stopping!')
break

print("Training complete.")
Loading