Source code for unet.utils

from typing import Tuple

import numpy as np


[docs]def crop_to_shape(data, shape: Tuple[int, int, int]): """ Crops the array to the given image shape by removing the border :param data: the array to crop, expects a tensor of shape [batches, nx, ny, channels] :param shape: the target shape [batches, nx, ny, channels] """ diff_nx = (data.shape[0] - shape[0]) diff_ny = (data.shape[1] - shape[1]) if diff_nx == 0 and diff_ny == 0: return data offset_nx_left = diff_nx // 2 offset_nx_right = diff_nx - offset_nx_left offset_ny_left = diff_ny // 2 offset_ny_right = diff_ny - offset_ny_left cropped = data[offset_nx_left:(-offset_nx_right), offset_ny_left:(-offset_ny_right)] assert cropped.shape[0] == shape[0] assert cropped.shape[1] == shape[1] return cropped
[docs]def crop_labels_to_shape(shape: Tuple[int, int, int]): def crop(image, label): return image, crop_to_shape(label, shape) return crop
[docs]def crop_image_and_label_to_shape(shape: Tuple[int, int, int]): def crop(image, label): return crop_to_shape(image, shape), \ crop_to_shape(label, shape) return crop
[docs]def to_rgb(img: np.array): """ Converts the given array into a RGB image and normalizes the values to [0, 1). If the number of channels is less than 3, the array is tiled such that it has 3 channels. If the number of channels is greater than 3, only the first 3 channels are used :param img: the array to convert [bs, nx, ny, channels] :returns img: the rgb image [bs, nx, ny, 3] """ img = img.astype(np.float32) img = np.atleast_3d(img) channels = img.shape[-1] if channels == 1: img = np.tile(img, 3) elif channels == 2: img = np.concatenate((img, img[..., :1]), axis=-1) elif channels > 3: img = img[..., :3] img[np.isnan(img)] = 0 img -= np.amin(img) if np.amax(img) != 0: img /= np.amax(img) return img