Source code for unet.unet

from typing import Optional, Union, Callable, List

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras.initializers import TruncatedNormal
from tensorflow.keras.optimizers import Adam

import unet.metrics


class ConvBlock(layers.Layer):
    """Two consecutive ``Conv2D -> Dropout -> Activation`` stages.

    Both convolutions share the same filter count, taken from
    ``_get_filter_count(layer_idx, filters_root)``, use a square kernel of
    side ``kernel_size`` and a stride of 1.

    :param layer_idx: index of the U-Net level this block belongs to
    :param filters_root: number of filters at level 0
    :param kernel_size: side length of the square convolution kernel
    :param dropout_rate: rate handed to both ``Dropout`` layers
    :param padding: padding mode handed to both ``Conv2D`` layers
    :param activation: activation handed to ``layers.Activation``
    :param kwargs: forwarded to ``layers.Layer.__init__``
    """

    def __init__(self, layer_idx, filters_root, kernel_size, dropout_rate, padding, activation, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are stored unchanged so get_config() can
        # report them.
        self.layer_idx = layer_idx
        self.filters_root = filters_root
        self.kernel_size = kernel_size
        self.dropout_rate = dropout_rate
        self.padding = padding
        self.activation = activation

        filters = _get_filter_count(layer_idx, filters_root)

        self.conv2d_1 = layers.Conv2D(filters=filters,
                                      kernel_size=(kernel_size, kernel_size),
                                      kernel_initializer=_get_kernel_initializer(filters, kernel_size),
                                      strides=1,
                                      padding=padding)
        self.dropout_1 = layers.Dropout(rate=dropout_rate)
        self.activation_1 = layers.Activation(activation)

        self.conv2d_2 = layers.Conv2D(filters=filters,
                                      kernel_size=(kernel_size, kernel_size),
                                      kernel_initializer=_get_kernel_initializer(filters, kernel_size),
                                      strides=1,
                                      padding=padding)
        self.dropout_2 = layers.Dropout(rate=dropout_rate)
        self.activation_2 = layers.Activation(activation)

    def call(self, inputs, training=None, **kwargs):
        """Run both convolution stages on ``inputs``.

        :param inputs: tensor fed to the first convolution
        :param training: training-mode flag, forwarded to the dropout layers
        :return: output of the second activation
        """
        x = self.conv2d_1(inputs)
        # The flag is forwarded to Dropout rather than tested with a Python
        # ``if``: Dropout only drops units in training mode, and forwarding
        # avoids coercing a non-bool (e.g. tensor-valued) flag to ``bool``.
        x = self.dropout_1(x, training=training)
        x = self.activation_1(x)

        x = self.conv2d_2(x)
        x = self.dropout_2(x, training=training)
        x = self.activation_2(x)
        return x

    def get_config(self):
        """Return the constructor arguments merged with the base layer config."""
        base_config = super().get_config()
        return dict(layer_idx=self.layer_idx,
                    filters_root=self.filters_root,
                    kernel_size=self.kernel_size,
                    dropout_rate=self.dropout_rate,
                    padding=self.padding,
                    activation=self.activation,
                    **base_config,
                    )
class UpconvBlock(layers.Layer):
    """Transposed convolution followed by an activation.

    The transposed convolution uses a ``pool_size`` x ``pool_size`` kernel
    with stride ``pool_size`` and emits half the filter count computed for
    level ``layer_idx + 1``.

    :param layer_idx: index of the U-Net level this block outputs to
    :param filters_root: number of filters at level 0
    :param kernel_size: kernel size used only to scale the initializer
    :param pool_size: kernel side length and stride of the transposed convolution
    :param padding: padding mode handed to ``Conv2DTranspose``
    :param activation: activation handed to ``layers.Activation``
    :param kwargs: forwarded to ``layers.Layer.__init__``
    """

    def __init__(self, layer_idx, filters_root, kernel_size, pool_size, padding, activation, **kwargs):
        super().__init__(**kwargs)
        # Constructor arguments are stored unchanged so get_config() can
        # report them.
        self.layer_idx = layer_idx
        self.filters_root = filters_root
        self.kernel_size = kernel_size
        self.pool_size = pool_size
        self.padding = padding
        self.activation = activation

        # Filter count of the level below; the block outputs half of it.
        lower_filters = _get_filter_count(layer_idx + 1, filters_root)
        initializer = _get_kernel_initializer(lower_filters, kernel_size)

        self.upconv = layers.Conv2DTranspose(
            lower_filters // 2,
            kernel_size=(pool_size, pool_size),
            kernel_initializer=initializer,
            strides=pool_size,
            padding=padding,
        )
        self.activation_1 = layers.Activation(activation)

    def call(self, inputs, **kwargs):
        """Apply the transposed convolution, then the activation."""
        upsampled = self.upconv(inputs)
        return self.activation_1(upsampled)

    def get_config(self):
        """Return the constructor arguments merged with the base layer config."""
        base_config = super().get_config()
        return dict(
            layer_idx=self.layer_idx,
            filters_root=self.filters_root,
            kernel_size=self.kernel_size,
            pool_size=self.pool_size,
            padding=self.padding,
            activation=self.activation,
            **base_config,
        )
class CropConcatBlock(layers.Layer):
    """Crop a skip tensor to the size of ``x`` and concatenate both."""

    def call(self, x, down_layer, **kwargs):
        """Crop ``down_layer`` around its centre and join it with ``x``.

        ``down_layer`` is sliced along axes 1 and 2 so that its extent on
        those axes equals that of ``x``; the crop starts at half the size
        difference (floor division). The cropped tensor and ``x`` are then
        concatenated along the last axis, cropped tensor first.

        :param x: tensor that defines the target size on axes 1 and 2
        :param down_layer: tensor to be cropped
        :return: concatenation of the cropped ``down_layer`` and ``x``
        """
        skip_shape = tf.shape(down_layer)
        target_shape = tf.shape(x)

        target_height = target_shape[1]
        target_width = target_shape[2]

        # Offsets of the crop window inside ``down_layer``.
        top = (skip_shape[1] - target_height) // 2
        left = (skip_shape[2] - target_width) // 2

        cropped = down_layer[:,
                             top: (target_height + top),
                             left: (target_width + left),
                             :]

        return tf.concat([cropped, x], axis=-1)
def build_model(nx: Optional[int] = None,
                ny: Optional[int] = None,
                channels: int = 1,
                num_classes: int = 2,
                layer_depth: int = 5,
                filters_root: int = 64,
                kernel_size: int = 3,
                pool_size: int = 2,
                dropout_rate: float = 0.5,
                padding: str = "valid",
                activation: Union[str, Callable] = "relu") -> Model:
    """
    Constructs a U-Net model

    :param nx: (Optional) image size on x-axis
    :param ny: (Optional) image size on y-axis
    :param channels: number of channels of the input tensors
    :param num_classes: number of classes
    :param layer_depth: total depth of unet, must be at least 1
    :param filters_root: number of filters in top unet layer
    :param kernel_size: size of convolutional layers
    :param pool_size: size of maxpool layers
    :param dropout_rate: rate of dropout
    :param padding: padding to be used in convolutions
    :param activation: activation to be used

    :return: A TF Keras model
    :raises ValueError: if ``layer_depth`` is smaller than 1
    """
    if layer_depth < 1:
        raise ValueError(f"layer_depth must be at least 1, got {layer_depth}")

    inputs = Input(shape=(nx, ny, channels), name="inputs")

    x = inputs
    contracting_layers = {}

    conv_params = dict(filters_root=filters_root,
                       kernel_size=kernel_size,
                       dropout_rate=dropout_rate,
                       padding=padding,
                       activation=activation)

    # Index of the deepest level. It is derived from ``layer_depth`` rather
    # than from a leftover loop variable, so ``layer_depth == 1`` (no
    # contracting/expanding steps at all) is handled as well.
    bottom_idx = layer_depth - 1

    # Contracting path: keep each level's output for the skip connection,
    # then downsample.
    for layer_idx in range(bottom_idx):
        x = ConvBlock(layer_idx, **conv_params)(x)
        contracting_layers[layer_idx] = x
        x = layers.MaxPooling2D((pool_size, pool_size))(x)

    x = ConvBlock(bottom_idx, **conv_params)(x)

    # Expanding path: upsample, merge with the stored output of the same
    # level, convolve.
    for layer_idx in range(bottom_idx - 1, -1, -1):
        x = UpconvBlock(layer_idx,
                        filters_root,
                        kernel_size,
                        pool_size,
                        padding,
                        activation)(x)
        x = CropConcatBlock()(x, contracting_layers[layer_idx])
        x = ConvBlock(layer_idx, **conv_params)(x)

    # 1x1 convolution producing one channel per class.
    x = layers.Conv2D(filters=num_classes,
                      kernel_size=(1, 1),
                      kernel_initializer=_get_kernel_initializer(filters_root, kernel_size),
                      strides=1,
                      padding=padding)(x)

    x = layers.Activation(activation)(x)
    outputs = layers.Activation("softmax", name="outputs")(x)
    model = Model(inputs, outputs, name="unet")

    return model
def _get_filter_count(layer_idx, filters_root):
    """Return ``filters_root`` scaled by two to the power of ``layer_idx``."""
    scale = 2 ** layer_idx
    return scale * filters_root


def _get_kernel_initializer(filters, kernel_size):
    """Build a ``TruncatedNormal`` initializer scaled to the kernel.

    The standard deviation is ``sqrt(2 / (kernel_size ** 2 * filters))``.

    :param filters: filter count entering the scale
    :param kernel_size: side length of the square kernel
    :return: a ``TruncatedNormal`` initializer with that standard deviation
    """
    denominator = kernel_size ** 2 * filters
    return TruncatedNormal(stddev=np.sqrt(2 / denominator))
def finalize_model(model: Model,
                   loss: Optional[Union[Callable, str]] = losses.categorical_crossentropy,
                   optimizer: Optional = None,
                   metrics: Optional[List[Union[Callable, str]]] = None,
                   dice_coefficient: bool = True,
                   auc: bool = True,
                   mean_iou: bool = True,
                   **opt_kwargs):
    """
    Configures the model for training by setting loss, optimizer, and tracked metrics

    :param model: the model to compile
    :param loss: the loss to be optimized. Defaults to `categorical_crossentropy`
    :param optimizer: the optimizer to use. Defaults to `Adam`
    :param metrics: List of metrics to track. Defaults to `categorical_crossentropy` and
        `categorical_accuracy`. The passed list is not modified; a copy of it is extended
        by the metrics enabled through the flags below
    :param dice_coefficient: Flag if the dice coefficient metric should be tracked
    :param auc: Flag if the area under the curve metric should be tracked
    :param mean_iou: Flag if the mean over intersection over union metric should be tracked
    :param opt_kwargs: key word arguments passed to default optimizer (Adam), e.g. learning rate.
        Only used when no `optimizer` is given
    """
    if optimizer is None:
        optimizer = Adam(**opt_kwargs)

    if metrics is None:
        tracked_metrics = ['categorical_crossentropy',
                           'categorical_accuracy',
                           ]
    else:
        # Work on a copy so the caller's list is not extended as a side effect.
        tracked_metrics = list(metrics)

    if mean_iou:
        tracked_metrics.append(unet.metrics.mean_iou)
    if dice_coefficient:
        tracked_metrics.append(unet.metrics.dice_coefficient)
    if auc:
        tracked_metrics.append(tf.keras.metrics.AUC())

    model.compile(loss=loss,
                  optimizer=optimizer,
                  metrics=tracked_metrics,
                  )