Source code for unet.schedulers

import logging
from enum import Enum
from typing import Callable

import tensorflow as tf
import tensorflow.keras.backend as K

logger = logging.getLogger(__name__)


class SchedulerType(Enum):
    WARMUP_LINEAR_DECAY = "warmup-linear-decay"
def get(scheduler: SchedulerType, train_dataset_size: int, learning_rate: float, **hyperparams):
    """Build the learning-rate scheduler callback for the given scheduler type."""
    if scheduler == SchedulerType.WARMUP_LINEAR_DECAY:
        batch_size = hyperparams["batch_size"]
        steps_per_epoch = (train_dataset_size + batch_size - 1) // batch_size
        total_steps = steps_per_epoch * hyperparams["epochs"]
        warmup_steps = int(total_steps * hyperparams["warmup_proportion"])
        logger.info("Total steps %s, warmup steps %s", total_steps, warmup_steps)
        schedule = WarmupLinearDecaySchedule(warmup_steps, total_steps, learning_rate)
        return LearningRateScheduler(schedule, steps_per_epoch, verbose=0)
    else:
        raise ValueError("Unknown scheduler %s" % scheduler)
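# A minimal usage sketch (not part of the module): the model, dataset, and concrete
# hyperparameter values below are illustrative assumptions; only the keyword names
# ("batch_size", "epochs", "warmup_proportion") are the ones actually read by get() above.
#
#     callback = get(
#         SchedulerType.WARMUP_LINEAR_DECAY,
#         train_dataset_size=10_000,
#         learning_rate=1e-3,
#         batch_size=32,
#         epochs=10,
#         warmup_proportion=0.1,
#     )
#     model.fit(train_dataset, epochs=10, callbacks=[callback])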
class LearningRateScheduler(tf.keras.callbacks.Callback):
    # Currently, the optimizers in TF2 do not properly support learning-rate schedules
    # passed as callables. As an alternative we use a Keras callback, which only allows
    # updating the learning rate per batch instead of per optimizer step.
    """Learning rate scheduler.

    Arguments:
        schedule: a function that takes a step index as input (integer, indexed from 0)
            and returns a new learning rate as output (float).
        steps_per_epoch: number of batches per training epoch.
        verbose: int. 0: quiet, 1: update messages.
    """

    def __init__(self, schedule: Callable[[int], float], steps_per_epoch: int, verbose=0):
        super(LearningRateScheduler, self).__init__()
        self.schedule = schedule
        self.steps_per_epoch = steps_per_epoch
        self.verbose = verbose
        self._current_step = 0
    def on_train_batch_begin(self, batch, logs=None):
        new_lr = self.schedule(self._current_step)
        K.set_value(self.model.optimizer.lr, new_lr)
        self._current_step += 1
        if self.verbose > 0:
            logger.info('\nBatch %05d: LearningRateScheduler changing learning rate to %s.',
                        batch + 1, new_lr)
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        logs['learning_rate'] = K.get_value(self.model.optimizer.lr)
    def on_train_batch_end(self, batch, logs=None):
        logs = logs or {}
        logs['learning_rate'] = K.get_value(self.model.optimizer.lr)
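# Illustrative sketch (not part of the module): LearningRateScheduler accepts any callable
# that maps a global step index to a learning rate, so a hand-rolled schedule such as the
# hypothetical constant_then_drop below works as well; the step threshold and rates are
# assumptions. Because the callback also writes "learning_rate" into the batch/epoch logs,
# downstream callbacks such as TensorBoard can pick it up alongside the other metrics.
#
#     def constant_then_drop(step):
#         return 1e-3 if step < 1000 else 1e-4
#
#     callback = LearningRateScheduler(constant_then_drop, steps_per_epoch=100, verbose=1)
#     model.fit(train_dataset, epochs=20, callbacks=[callback])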
class WarmupLinearDecaySchedule:
    """Linear warmup followed by linear decay.

    Linearly increases the learning rate from 0 to `learning_rate` over `warmup_steps`
    training steps, then linearly decreases it from `learning_rate` to `min_lr` over the
    remaining `total_steps - warmup_steps` steps.
    """

    def __init__(self, warmup_steps, total_steps, learning_rate, min_lr=0.0):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.initial_learning_rate = learning_rate
        self.min_lr = min_lr
        self.decay_steps = max(1.0, self.total_steps - self.warmup_steps)

    def __call__(self, step):
        if step < self.warmup_steps:
            # Warmup: scale the learning rate linearly with the step count.
            learning_rate = self.initial_learning_rate * float(step) / max(1., self.warmup_steps)
        else:
            # Decay: interpolate linearly from the initial learning rate down to min_lr.
            decay_factor = max(0, (self.total_steps - step) / self.decay_steps)
            learning_rate = self.min_lr + (self.initial_learning_rate - self.min_lr) * decay_factor
        return learning_rate
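# Illustrative sketch (not part of the module): with 100 warmup steps out of 1000 total
# steps and a peak learning rate of 1e-3 (all assumed values), the schedule ramps up
# linearly and then decays linearly back to min_lr:
#
#     schedule = WarmupLinearDecaySchedule(warmup_steps=100, total_steps=1000, learning_rate=1e-3)
#     schedule(0)     # 0.0    (start of warmup)
#     schedule(50)    # 5e-4   (halfway through warmup)
#     schedule(100)   # 1e-3   (peak, warmup finished)
#     schedule(550)   # 5e-4   (halfway through the decay)
#     schedule(1000)  # 0.0    (end of training, min_lr)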