U-NETS For Dummies (PyTorch & TensorFlow) | by Ryan Chew | Sep, 2024


Everything you need to know (and a lot more you don’t need to know) about U-NETs: a deep dive into what a U-NET is, why we use U-NETs, the math that goes into a U-NET, and code for a U-NET in both PyTorch and TensorFlow.

Image Source: https://smilegate.ai/en/2021/07/05/tensorflow-vs-pytorch/

· What is A U-NET?
· Why U-NET?
· U-NET Drawbacks
· U-NET Architecture
· What Does Data Look Like For U-NETs?
· Exploring Loss Functions of U-NETs
· PyTorch U-NET Code
· TensorFlow U-NET Code

U-NET is a model architecture introduced in 2015 in a paper that demonstrated its powerful capabilities in biomedical image segmentation. It gets its name from the U shape of its architecture.

Now, many segmentation tasks don’t require a U-NET; for instance, fully convolutional networks (FCNs) and DeepLab are often used for semantic segmentation with great results. The major benefit of a U-NET is the amount of detail preserved by encoding and then decoding the image, making it the go-to model for medical imaging tasks, which often lack a clear, distinct subject. Specifically, the contracting path extracts features and captures context by increasing depth, before the expanding path returns the image to its original shape for precise localization. The image below shows a heart CT scan that contains a calcium deposit. Instead of a clear subject to segment, the entire image looks almost uniform, making the calcium deposit incredibly hard to identify.

Image Source: https://stanfordaimi.azurewebsites.net/datasets/e8ca74dc-8dd4-4340-815a-60b41f6cb2aa

Just because the U-NET architecture can work doesn’t mean it is always the way to go. One of its biggest drawbacks is the computational resources it requires. As we will see below, each stage of the network contains multiple convolutional layers, which means the model has many parameters. Without an abundance of resources, it is often difficult to train large U-NET models.

Image Source: https://arxiv.org/pdf/1505.04597

Before we look at the U-NET architecture as a whole, it is important to understand the components that make it up. The contracting path is made up of blocks that each capture context while reducing the spatial dimensions and increasing the depth. Each block consists of convolutional layers followed by a pooling layer. How does this work?

Single Block of U-NET

The image above shows one block. We start with an image of dimensions 1×572×572, representing a single color channel and a 572-by-572-pixel image. Next, the convolutional layer increases the depth from 1 channel to 64. Notice that the size of the image is reduced from 572×572 to 570×570. Let’s take a look at a 2D convolution. In the image, a 3×3 kernel is used; this means that to keep the full 3×3 square inside the image, the kernel must be centered starting on the second pixel, as shown below (left). If you attempt to center a 3×3 kernel on a corner or an edge, part of it will fall outside the image, as shown below (right).

2D-Convolution with 3×3 kernel

However, there is a way to prevent this. Adding padding to the convolutional layers keeps the image size constant until the pooling at the end of the block.
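
As a quick illustration, here is a minimal PyTorch sketch (the layer sizes mirror the figure above): a 3×3 convolution without padding shrinks each side by two pixels, while padding=1 preserves the size.

import torch
from torch.nn import Conv2d

x = torch.randn(1, 1, 572, 572)                    # a batch of one grayscale 572x572 image

no_pad = Conv2d(1, 64, kernel_size=3)              # no padding: each side shrinks by 2
print(no_pad(x).shape)                             # torch.Size([1, 64, 570, 570])

padded = Conv2d(1, 64, kernel_size=3, padding=1)   # padding=1: size is preserved
print(padded(x).shape)                             # torch.Size([1, 64, 572, 572])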

Next, we take a look at the process of reducing the spatial dimensions. Using pooling, we can halve the size of each side with a 2×2 kernel. The most commonly used is max pooling, where only the largest of the four values in each 2×2 window is kept.

2D Max Pooling 2×2 Kernel
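
A minimal sketch of 2×2 max pooling in PyTorch (the shapes are chosen to match the running example), showing each side being halved:

import torch
from torch.nn import MaxPool2d

x = torch.randn(1, 64, 572, 572)
pool = MaxPool2d(kernel_size=2)   # 2x2 window, stride 2
print(pool(x).shape)              # torch.Size([1, 64, 286, 286])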

Finally, we repeat that block, continuing to reduce the spatial dimensions and increase the depth until we reach our bottleneck. After that, we enter the expanding path, where we mirror the contracting path but perform transposed convolutions to return the image to its original dimensions. On top of that, we crop and copy over the information from the corresponding block of the contracting path to retain the information captured. Let’s take a deeper dive:

Notice how the two back-to-back 2D convolutions are identical to the ones in the contracting path, except that instead of increasing the depth, they decrease it by a factor of two. The image above shows the convolution taking a 512×104×104 image to 256×102×102, halving the depth (the number of channels). Then the green arrow, which represents a transposed convolution, doubles the spatial dimensions. How does this work?

Transposed Convolution

In a transposed convolution, each input value is multiplied by an n×m kernel (where n and m are integers greater than one), and the resulting patches are written into the output, summed wherever they overlap. Notice that this is the direct opposite of a regular convolution: instead of sliding a window over the input and collapsing each patch into a single value, each value of the input is expanded into a full kernel-sized patch. This means that the final image is larger than the input.
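
A minimal sketch with PyTorch’s ConvTranspose2d, using a 2×2 kernel with stride 2 (the configuration used in the code later on), which doubles each spatial side while halving the depth:

import torch
from torch.nn import ConvTranspose2d

x = torch.randn(1, 512, 104, 104)
up = ConvTranspose2d(512, 256, kernel_size=2, stride=2)  # 2x2 kernel, stride 2
print(up(x).shape)                                       # torch.Size([1, 256, 208, 208])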

But what is the point of the contracting path if the expanding path just reverses the transformations?

After each block of the contracting path, a center crop is performed on the resulting feature map. This cropped output is then concatenated with the corresponding output of the expanding path. This preserves the information gained in the contracting path while ensuring the two feature maps have matching spatial dimensions.

For this discussion, we will be talking specifically about binary image segmentation. Oftentimes, data will be presented as PNG (.png) images or DICOM (.dcm) files for CT scans. To train a model, we must have these images as well as a labeled ground truth. The ground truth is an image of the same dimensions as the original, with a binary label for every pixel. Say you are segmenting a heart nodule within an image of a heart CT scan. The pixels that are part of the nodule are represented with a 1 in the mask, while any pixels that are not part of the nodule are labeled with a 0. Correspondingly, the output of a U-NET is an array of floating point numbers between 0 and 1, one per pixel, which can be thresholded to produce the predicted mask.
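
As a rough sketch of what an image/mask pair might look like in code (the file names below are hypothetical placeholders, not files from any particular dataset):

import numpy as np
import torch
from PIL import Image

# Hypothetical file paths, for illustration only
image = Image.open("scan_001.png").convert("L")       # grayscale CT slice
mask = Image.open("scan_001_mask.png").convert("L")   # ground-truth segmentation mask

x = torch.from_numpy(np.array(image)).float().unsqueeze(0) / 255.0  # shape (1, H, W), values in [0, 1]
y = (torch.from_numpy(np.array(mask)) > 0).float().unsqueeze(0)     # shape (1, H, W), binary 0/1 labels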

The first loss function we explore is among the most common for medical applications: the Dice coefficient. It is defined as two times the intersection divided by the sum of the areas of the prediction and the ground truth.

Dice Coefficient

The next loss function is IoU, which stands for intersection over union and does exactly what the name describes: it takes the intersection between the ground truth and the prediction and divides it by their union.

Intersection Over Union

Lastly, we take a look at binary cross entropy. Binary cross entropy, also known as log loss, iterates over every pixel in the image and averages the negative log-likelihood of the predicted probability for each pixel’s true label. It shows good performance in models with smaller segmentations.

Binary Cross Entropy (Log Loss)
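
To make these concrete, here is a minimal PyTorch sketch of the three losses for batches of binary masks (the small smoothing term is an assumption added to avoid division by zero; it is not part of the formulas above):

import torch
import torch.nn.functional as F

def dice_loss(pred, target, smooth=1e-6):
    # pred and target: tensors of shape (N, 1, H, W) with values in [0, 1]
    intersection = (pred * target).sum()
    dice = (2 * intersection + smooth) / (pred.sum() + target.sum() + smooth)
    return 1 - dice

def iou_loss(pred, target, smooth=1e-6):
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum() - intersection
    return 1 - (intersection + smooth) / (union + smooth)

def bce_loss(pred, target):
    # average of the per-pixel negative log-likelihoods
    return F.binary_cross_entropy(pred, target)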

Let’s start with our dependencies:

import torch
from torch import nn
from torch.nn import Conv2d, MaxPool2d, ReLU, ConvTranspose2d
import torchvision
from torchvision.transforms import CenterCrop

We start by creating a single block that contains our double convolution. We also use Xavier initialization to randomize our initial weights. As noted before, a padding of one helps retain the original image dimensions. As we can see below, we must specify the number of input channels and output channels. On the contracting path, we will be increasing the number of channels, starting with 1 channel for grayscale or 3 channels for RGB.

class Block(nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()

        self.conv1 = Conv2d(in_ch, out_ch, 3, padding=1)
        self.relu = ReLU()
        self.conv2 = Conv2d(out_ch, out_ch, 3, padding=1)

        # init random weights
        nn.init.xavier_normal_(self.conv1.weight)
        nn.init.zeros_(self.conv1.bias)

        nn.init.xavier_normal_(self.conv2.weight)
        nn.init.zeros_(self.conv2.bias)

    def forward(self, x):
        return self.relu(self.conv2(self.relu(self.conv1(x))))

Next, we will make use of these convolutional blocks by adding them to our encoding path.

The channels parameter is a tuple of integers representing all of the channel depths we will transform our image into along the contracting path. We start with 1, as our example uses grayscale images. We see that self.encoder iterates through the tuple of channel numbers and creates a list of blocks, each with its own input and output channel count.

In the forward pass of the autoencoder, we loop through every block, saving each result in a list called skip_out. These skip outputs will later be used in the expanding path, so the list is returned for later use.

After passing the input through the block, we use a 2×2 pool to reduce the spatial dimension.

class AutoEncoder(nn.Module):
    def __init__(self, channels=(1, 32, 64, 128, 256, 512, 1024)):
        super().__init__()
        self.pool = MaxPool2d((2, 2))

        self.encoder = nn.ModuleList([
            Block(channels[i], channels[i+1]) for i in range(len(channels)-1)  # 1, 32, ..., 1024
        ])

    def forward(self, x):
        skip_out = []
        for block in self.encoder:  # goes through all blocks, passes x through and saves the skip output
            x = block(x)
            skip_out.append(x)
            x = self.pool(x)  # reduces spatial dimensions
        return skip_out

We now have to create the other side of the U-NET, which is the auto decoder. Let’s take a look at how the channels here differ from the autoencoder.

To simplify things, we take the exact same channels tuple as input. However, instead of moving from a smaller depth to a larger depth, we must go in reverse. Note that self.decoder looks the same as the encoder within our autoencoder; this is because we use the same double convolution on both sides of the U-NET. This time, however, we must also create a list of transposed convolutions in place of the pooling we used in the autoencoder.

Our forward pass starts with the up-convolution, which takes the image out of the bottleneck that the autoencoder created and into the first block.

Then, we must take the output from the corresponding block of the autoencoder and concatenate it with our image. To accomplish this, we take the input enc_out, which is the list of autoencoder outputs that we saved as skip_out. Each autoencoder output must first be center-cropped to match our image before concatenation.

Lastly, we run the image through the double convolution before using min-max scaling to keep our values between 0 and 1.

class AutoDecoder(nn.Module):
    def __init__(self, channels=(1, 32, 64, 128, 256, 512, 1024)):
        super().__init__()
        # Reverse of encoder channels (excluding the first, which is not needed in the output)
        self.channels = channels[:0:-1]

        self.upconv = nn.ModuleList([
            ConvTranspose2d(self.channels[i], self.channels[i+1], 2, 2) for i in range(len(self.channels)-1)
        ])

        self.decoder = nn.ModuleList([
            Block(self.channels[i], self.channels[i+1]) for i in range(len(self.channels)-1)
        ])

    def center_crop(self, x, enc_out):  # crop encoder output to match x
        _, _, h, w = x.shape
        enc_out = CenterCrop([h, w])(enc_out)
        return enc_out

    def forward(self, x, enc_out: list):
        for i in range(len(self.channels)-1):
            x = self.upconv[i](x)
            enc_ftrs = self.center_crop(x, enc_out[i])  # crop skip connection
            x = torch.cat([x, enc_ftrs], dim=1)         # concatenate decoder output and skip
            x = self.decoder[i](x)

        # Min-max scaling to [0, 1]
        x = (x - x.min()) / (x.max() - x.min())
        return x

Finally, we combine our autoencoder and autodecoder to create our U-NET.

We define our autoencoder and auto-decoder using the same set of channels.

Next, we have to handle what happens to our image after the expanding path: we must include a final step that produces the output segmentation map. For this, we use a 2D convolution with an output depth of 1. Again, we randomly initialize its weights with Xavier initialization.

The rest of the model is just putting the pieces together. We start by running our image through the autoencoder, saving all of the skip connection outputs in a list called skips.

Next, we run the image through the auto-decoder. We must begin with the last output of our autoencoder, which also happens to be the last element of the list skips. The other parameter taken by our auto-decoder is the list of skip connections, which is the rest of the reversed list, excluding the bottleneck.

Lastly, we run our 2d convolution and use sigmoid to output a floating point number between 0 and 1.

class UNET(nn.Module):
    def __init__(self, channels=(1, 32, 64, 128, 256, 512, 1024)):
        super().__init__()

        # Encoder path
        self.enc_path = AutoEncoder(channels)

        # Decoder path
        self.dec_path = AutoDecoder(channels)

        self.out = Conv2d(channels[1], 1, 1)

        # init random weights
        nn.init.xavier_normal_(self.out.weight)
        nn.init.zeros_(self.out.bias)

    def forward(self, x):
        skips = self.enc_path(x)
        x = self.dec_path(skips[::-1][0], skips[::-1][1:])
        # Reversed skips = the order needed on the decoder's upward path
        # [0]  -> bottleneck (1024-channel) output
        # [1:] -> all other skip outputs
        x = self.out(x)
        x = torch.sigmoid(x)

        return x
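
A quick sanity check of the full model (a minimal sketch; the 128×128 input size is an arbitrary choice for illustration):

model = UNET()
x = torch.randn(1, 1, 128, 128)   # one grayscale 128x128 image
with torch.no_grad():
    pred = model(x)
print(pred.shape)                 # torch.Size([1, 1, 128, 128])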

We will follow a similar process for TensorFlow. To begin, we must import our dependencies.

import tensorflow as tf
from tensorflow.keras.layers import Conv2D, ReLU, MaxPooling2D, UpSampling2D, Concatenate, Input
from tensorflow.keras.models import Model

Again, we define a double convolution block.

def Block(x, output):
    x = Conv2D(output, 3, padding="same")(x)
    x = ReLU()(x)
    x = Conv2D(output, 3, padding="same")(x)
    x = ReLU()(x)
    return x

Next, we use these blocks in our U-NET model. This works much the same as our PyTorch model. After each block of the contracting path, we use 2×2 max pooling to reduce the spatial dimensions. After every block of the expanding path, a 2×2 upsampling layer (UpSampling2D here, rather than a learned transposed convolution) increases the spatial dimensions before the corresponding skip connection is concatenated.

def UNET(input_shape=(128, 128, 1), num_classes=1):
    inputs = Input(input_shape)

    # Encoder
    down1 = Block(inputs, 64)
    pool1 = MaxPooling2D((2, 2))(down1)

    down2 = Block(pool1, 128)
    pool2 = MaxPooling2D((2, 2))(down2)

    down3 = Block(pool2, 256)
    pool3 = MaxPooling2D((2, 2))(down3)

    down4 = Block(pool3, 512)
    pool4 = MaxPooling2D((2, 2))(down4)

    # Bottleneck
    bottleneck = Block(pool4, 1024)

    # Decoder
    up4 = UpSampling2D((2, 2))(bottleneck)
    up4 = Concatenate()([up4, down4])
    up4 = Block(up4, 512)

    up3 = UpSampling2D((2, 2))(up4)
    up3 = Concatenate()([up3, down3])
    up3 = Block(up3, 256)

    up2 = UpSampling2D((2, 2))(up3)
    up2 = Concatenate()([up2, down2])
    up2 = Block(up2, 128)

    up1 = UpSampling2D((2, 2))(up2)
    up1 = Concatenate()([up1, down1])
    up1 = Block(up1, 64)

    # Output layer
    outputs = Conv2D(num_classes, 1, padding="same", activation="sigmoid")(up1)

    # Create the model
    model = Model(inputs, outputs)
    return model
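
A minimal usage sketch (the loss and optimizer choices here are example assumptions, not part of the model definition above):

model = UNET(input_shape=(128, 128, 1), num_classes=1)
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
model.summary()

# model.fit(train_images, train_masks, epochs=10, batch_size=8)  # hypothetical training data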


