from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, UpSampling2D, concatenate from tensorflow.keras.models import Model import tensorflow_probability as tfp from groupnorm import GroupNormalization from utils import normal_prior def down_stage(inputs, filters, kernel_size=3, activation="relu", padding="SAME"): conv = Conv2D(filters, kernel_size, activation=activation, padding=padding)(inputs) conv = GroupNormalization()(conv) conv = Conv2D(filters, kernel_size, activation=activation, padding=padding)(conv) conv = GroupNormalization()(conv) pool = MaxPooling2D()(conv) return conv, pool def up_stage(inputs, skip, filters, prior_fn, kernel_size=3, activation="relu", padding="SAME"): up = UpSampling2D()(inputs) up = tfp.layers.Convolution2DFlipout(filters, 2, activation=activation, padding=padding, kernel_prior_fn=prior_fn)(up) up = GroupNormalization()(up) merge = concatenate([skip, up]) merge = GroupNormalization()(merge) conv = tfp.layers.Convolution2DFlipout(filters, kernel_size, activation=activation, padding=padding, kernel_prior_fn=prior_fn)(merge) conv = GroupNormalization()(conv) conv = tfp.layers.Convolution2DFlipout(filters, kernel_size, activation=activation, padding=padding, kernel_prior_fn=prior_fn)(conv) conv = GroupNormalization()(conv) return conv def end_stage(inputs, prior_fn, kernel_size=3, activation="relu", padding="SAME"): conv = tfp.layers.Convolution2DFlipout(1, kernel_size, activation=activation, padding="SAME", kernel_prior_fn=prior_fn)(inputs) conv = tfp.layers.Convolution2DFlipout(1, 1, activation="sigmoid", kernel_prior_fn=prior_fn)(conv) return conv def bayesian_unet(input_shape=(280, 280, 1), kernel_size=3, activation="relu", padding="SAME", **kwargs): prior_std = kwargs.get("prior_std", 1) prior_fn = normal_prior(prior_std) inputs = Input(input_shape) conv1, pool1 = down_stage(inputs, 16, kernel_size=kernel_size, activation=activation, padding=padding) conv2, pool2 = down_stage(pool1, 32, kernel_size=kernel_size, activation=activation, padding=padding) conv3, pool3 = down_stage(pool2, 64, kernel_size=kernel_size, activation=activation, padding=padding) conv4, _ = down_stage(pool3, 128, kernel_size=kernel_size, activation=activation, padding=padding) conv5 = up_stage(conv4, conv3, 64, prior_fn, kernel_size=kernel_size, activation=activation, padding=padding) conv6 = up_stage(conv5, conv2, 32, prior_fn, kernel_size=kernel_size, activation=activation, padding=padding) conv7 = up_stage(conv6, conv1, 16, prior_fn, kernel_size=kernel_size, activation=activation, padding=padding) conv8 = end_stage(conv7, prior_fn, kernel_size=kernel_size, activation=activation, padding=padding) return Model(inputs=inputs, outputs=conv8)