Training settings
Please provide a valid training processor option
Neural network architecture
import os
import sys

import tensorflow as tf
import numpy as np
from akida_models import akidanet_imagenet
from keras import Model
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.layers import BatchNormalization, Conv2D, Softmax, ReLU
from cnn2snn import check_model_compatibility
from ei_tensorflow.constrained_object_detection import models, dataset, metrics, util
# Root directory that holds the pretrained weight files. Taken from the
# WEIGHTS_PREFIX environment variable when set, else the current directory.
WEIGHTS_PREFIX = os.getenv('WEIGHTS_PREFIX', os.getcwd())
def build_model(input_shape: tuple, alpha: float,
                num_classes: int, weight_regularizer=None) -> tf.keras.Model:
    """Construct a constrained object detection (FOMO) model on AkidaNet.

    Args:
        input_shape: Passed to AkidaNet construction.
        alpha: AkidaNet alpha value.
        num_classes: Number of classes, i.e. final dimension size, in output.
        weight_regularizer: Optional Keras regularizer applied to the kernels
            of the two 1x1 convolutions in the detection head.

    Returns:
        Uncompiled keras model.
        Model takes (B, H, W, C) input and
        returns (B, H//8, W//8, num_classes) logits.

    Exits the process with status 1 if the assembled model is reported as
    not compatible with Akida by ``check_model_compatibility``.
    """
    # AkidaNet backbone without its classification top.
    a_base_model = akidanet_imagenet(input_shape=input_shape,
                                     alpha=alpha,
                                     include_top=False,
                                     input_scaling=None)

    # Load pretrained weights into the backbone.
    # Available base models are:
    #   akidanet_imagenet_224.h5          - float32 model, 224x224x3, alpha=1.00
    #   akidanet_imagenet_224_alpha_50.h5 - float32 model, 224x224x3, alpha=0.50
    #   akidanet_imagenet_224_alpha_25.h5 - float32 model, 224x224x3, alpha=0.25
    #   akidanet_imagenet_160.h5          - float32 model, 160x160x3, alpha=1.00
    #   akidanet_imagenet_160_alpha_50.h5 - float32 model, 160x160x3, alpha=0.50
    #   akidanet_imagenet_160_alpha_25.h5 - float32 model, 160x160x3, alpha=0.25
    # NOTE: the alpha=0.50 / 224 file is always used here. Weights are matched
    # by layer name with skip_mismatch=True, so any layer whose weight shapes
    # differ from the file (e.g. when a different alpha is requested) silently
    # keeps its initial weights.
    pretrained_weights = os.path.join(
        WEIGHTS_PREFIX,
        'transfer-learning-weights/akidanet/akidanet_imagenet_224_alpha_50.h5')
    a_base_model.load_weights(pretrained_weights, by_name=True, skip_mismatch=True)
    a_base_model.trainable = True

    # Default batch norm momentum is configured for huge networks; a lower
    # momentum lets the moving statistics update faster. isinstance() also
    # covers BatchNormalization subclasses.
    for layer in a_base_model.layers:
        if isinstance(layer, BatchNormalization):
            layer.momentum = 0.9

    # Cut AkidaNet where it hits 1/8th input resolution; i.e. (HW/8, HW/8, C).
    a_cut_point = a_base_model.get_layer('separable_5_relu')

    # Small additional head: 1x1 conv (32 filters) + ReLU, followed by a 1x1
    # conv with no activation that emits one logit per class per grid cell.
    a_model_part_head = Conv2D(filters=32, kernel_size=1, strides=1, padding='same',
                               kernel_regularizer=weight_regularizer)(a_cut_point.output)
    a_model_part = ReLU()(a_model_part_head)
    a_logits = Conv2D(filters=num_classes, kernel_size=1, strides=1, padding='same',
                      activation=None, kernel_regularizer=weight_regularizer)(a_model_part)

    fomo_akida = Model(inputs=a_base_model.input, outputs=a_logits)

    # Check Akida compatibility now so we fail quickly, before training.
    if not check_model_compatibility(fomo_akida, input_is_image=True):
        print("Model is not compatible with Akida!")
        sys.exit(1)

    return fomo_akida
def train(num_classes: int, learning_rate: float, num_epochs: int,
          alpha: float, object_weight: int,
          train_dataset: tf.data.Dataset,
          validation_dataset: tf.data.Dataset,
          best_model_path: str,
          input_shape: tuple,
          callbacks: 'list',
          quantize_function,
          qat_function,
          batch_size: int,
          lr_finder: bool = False,
          ensure_determinism: bool = False) -> 'tuple[tf.keras.Model, tf.keras.Model]':
    """Construct, train, quantize and QAT-fine-tune a constrained object detection model.

    Args:
        num_classes: Number of classes in datasets. This does not include the
            implied background class introduced by segmentation map dataset
            conversion.
        learning_rate: Learning rate for Adam (replaced when lr_finder is True).
        num_epochs: Number of epochs passed to model.fit for float training.
        alpha: Alpha used to construct AkidaNet. Pretrained weights are loaded
            by layer name; layers whose shapes don't match are skipped.
        object_weight: The weighting to give the object in the loss function
            where background has an implied weight of 1.0.
        train_dataset: Training dataset of (x, (bbox, one_hot_y)).
        validation_dataset: Validation dataset of (x, (bbox, one_hot_y)).
        best_model_path: Location to save the best model weights. Note:
            weights are restored from this path based on best val_f1 score.
        input_shape: The shape of the model's input (three dims, square).
        callbacks: List of extra Keras callbacks. It is not modified; new
            lists are built for each fit.
        quantize_function: Akida quantize function, called as
            ``quantize_function(keras_model=...)``.
        qat_function: Akida quantize-aware training function.
        batch_size: Training batch size.
        lr_finder: If True, the learning_rate will be replaced with a value
            found by the learning rate finder.
        ensure_determinism: If true, functions that may be non-deterministic
            are disabled (e.g. autotuning prefetch, shuffling). This should be
            true in test environments.

    Returns:
        A ``(model, akida_model)`` tuple. ``model`` is the float Keras model
        with the best val_f1 weights restored and an explicit softmax
        appended. ``akida_model`` is what ``qat_function`` returns for the
        quantized model.

    Raises:
        ValueError: If input_shape does not unpack into exactly three values.
        Exception: If the input or the model output grid is not square.

    The model has num_classes+1 outputs (class 0 is the implied background).
    Both datasets are adapted from (x, (bbox, one_hot_y)) to
    (x, segmentation_map) and the model is trained with a custom weighted
    cross entropy. Exits the process with status 1 if the softmax model is
    reported as not compatible with Akida.
    """
    num_classes_with_background = num_classes + 1

    # Only square inputs are supported. Unpacking also enforces exactly three
    # dimensions (presumably (height, width, channels)).
    in_rows, in_cols, _in_channels = input_shape
    if in_rows != in_cols:
        raise Exception(f"Only square inputs are supported; not {input_shape}")

    model = build_model(input_shape=input_shape,
                        alpha=alpha,
                        num_classes=num_classes_with_background,
                        weight_regularizer=tf.keras.regularizers.l2(4e-5))

    # Derive the output grid size from the model. Separate names are used so
    # the `num_classes` argument is not shadowed by the output channel count.
    model_output_shape = model.layers[-1].output.shape
    _batch, out_rows, out_cols, _out_channels = model_output_shape
    if out_rows != out_cols:
        raise Exception(f"Only square outputs are supported; not {model_output_shape}")
    output_width_height = out_rows

    # Weighted cross entropy loss specific to this model's output size.
    weighted_xent = models.construct_weighted_xent_fn(model.output.shape, object_weight)

    # Autotuned prefetch may be non-deterministic; use a fixed depth of 1 when
    # determinism is requested.
    prefetch_policy = 1 if ensure_determinism else tf.data.experimental.AUTOTUNE

    def as_segmentation(ds, shuffle):
        """Map bounding box labels to segmentation maps, then shuffle/batch/prefetch."""
        ds = ds.map(dataset.bbox_to_segmentation(output_width_height, num_classes_with_background))
        if shuffle and not ensure_determinism:
            ds = ds.shuffle(buffer_size=batch_size * 4)
        return ds.batch(batch_size, drop_remainder=False).prefetch(prefetch_policy)

    train_segmentation_dataset = as_segmentation(train_dataset, True)
    validation_segmentation_dataset = as_segmentation(validation_dataset, False)
    # The centroid-scoring callback receives the un-mapped
    # (x, (bbox, one_hot_y)) validation data, batched the same way.
    validation_dataset_for_callback = (validation_dataset
                                       .batch(batch_size, drop_remainder=False)
                                       .prefetch(prefetch_policy))

    # Initialise bias of final classifier based on training data prior.
    util.set_classifier_biases_from_dataset(model, train_segmentation_dataset)

    if lr_finder:
        # `from ei_tensorflow.constrained_object_detection import ...` does
        # not bind the name `ei_tensorflow`, so import the finder explicitly.
        import ei_tensorflow.lr_finder
        learning_rate = ei_tensorflow.lr_finder.find_lr(model, train_segmentation_dataset, weighted_xent)

    opt = Adam(learning_rate=learning_rate)
    model.compile(loss=weighted_xent, optimizer=opt)

    # Centroid scoring against validation data at the end of each epoch, plus
    # a % progress printer for slow runs.
    centroid_callback = metrics.CentroidScoring(validation_dataset_for_callback,
                                                output_width_height, num_classes_with_background)
    print_callback = metrics.PrintPercentageTrained(num_epochs)
    # Checkpoint weights on the best 'val_f1' (presumably logged by
    # CentroidScoring).
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(best_model_path,
                                                             monitor='val_f1', save_best_only=True, mode='max',
                                                             save_weights_only=True, verbose=0)

    model.fit(train_segmentation_dataset,
              validation_data=validation_segmentation_dataset,
              epochs=num_epochs,
              callbacks=callbacks + [centroid_callback, print_callback, checkpoint_callback],
              verbose=0)

    # Restore best weights.
    model.load_weights(best_model_path)

    # Add explicit softmax layer before export.
    softmax_layer = Softmax()(model.layers[-1].output)
    model = Model(model.input, softmax_layer)

    if not check_model_compatibility(model, input_is_image=True):
        print("Model is not compatible with Akida!")
        sys.exit(1)

    # Post-training quantization (4/4/8 bits with akida_quantize_model's
    # defaults).
    akida_model = quantize_function(keras_model=model)

    # Quantization-aware fine-tuning. The float-training optimizer instance is
    # reused.
    akida_model = qat_function(akida_model=akida_model,
                               train_dataset=train_segmentation_dataset,
                               validation_dataset=validation_segmentation_dataset,
                               optimizer=opt,
                               fine_tune_loss=weighted_xent,
                               fine_tune_metrics=None,
                               callbacks=callbacks + [centroid_callback, print_callback],
                               stopping_metric='val_f1',
                               fit_verbose=0)

    return model, akida_model
# Training hyperparameters: use the command-line value when it is truthy,
# otherwise fall back to the built-in default.
EPOCHS = args.epochs if args.epochs else 60
LEARNING_RATE = args.learning_rate if args.learning_rate else 0.001
BATCH_SIZE = args.batch_size if args.batch_size else 128
import tensorflow as tf
def akida_quantize_model(
    keras_model,
    weight_quantization: int = 4,
    activ_quantization: int = 4,
    input_weight_quantization: int = 8,
):
    """Apply post-training quantization to a Keras model via ``cnn2snn.quantize``.

    Args:
        keras_model: The float Keras model to quantize.
        weight_quantization: Weight bit width forwarded to cnn2snn (default 4).
        activ_quantization: Activation bit width forwarded to cnn2snn (default 4).
        input_weight_quantization: Bit width forwarded to cnn2snn, presumably
            for the first (input) layer's weights (default 8) - see cnn2snn docs.

    Returns:
        The model returned by ``cnn2snn.quantize``.
    """
    import cnn2snn

    print("Performing post-training quantization...")
    bit_widths = dict(
        weight_quantization=weight_quantization,
        activ_quantization=activ_quantization,
        input_weight_quantization=input_weight_quantization,
    )
    quantized = cnn2snn.quantize(keras_model, **bit_widths)
    # Trailing newline reproduces the blank separator line in the log.
    print("Performing post-training quantization OK\n")
    return quantized
def akida_perform_qat(
    akida_model,
    train_dataset: tf.data.Dataset,
    validation_dataset: tf.data.Dataset,
    optimizer,
    fine_tune_loss,
    fine_tune_metrics: "list[str] | None",
    callbacks,
    stopping_metric: str = "val_accuracy",
    fit_verbose: int = 2,
    qat_epochs: int = 30,
):
    """Fine-tune a quantized model with quantization-aware training (QAT).

    Compiles ``akida_model`` and fits it for up to ``qat_epochs`` epochs.
    Training stops early once ``stopping_metric`` has not improved (mode
    "max") for 10 consecutive epochs, and the best weights are restored.

    Args:
        akida_model: Quantized Keras model to fine-tune (e.g. the output of
            akida_quantize_model).
        train_dataset: Training dataset passed to ``fit``.
        validation_dataset: Validation dataset passed to ``fit``.
        optimizer: Optimizer name or instance passed to ``compile``.
        fine_tune_loss: Loss name or callable passed to ``compile``.
        fine_tune_metrics: Metrics passed to ``compile`` (may be None).
        callbacks: Iterable of extra Keras callbacks, or None. The caller's
            object is never modified.
        stopping_metric: Metric monitored by EarlyStopping (maximized).
        fit_verbose: ``verbose`` value passed to ``fit``.
        qat_epochs: Maximum number of QAT epochs.

    Returns:
        The same ``akida_model`` object, trained in place.
    """
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor=stopping_metric,
        mode="max",
        verbose=1,
        min_delta=0,
        patience=10,
        restore_best_weights=True,
    )
    # Build a new list rather than appending to `callbacks`. Mutating the
    # caller's list would leak EarlyStopping into any later fit that reuses it
    # and accumulate duplicates across calls.
    qat_callbacks = list(callbacks or []) + [early_stopping]

    print("Running quantization-aware training...")
    akida_model.compile(
        optimizer=optimizer, loss=fine_tune_loss, metrics=fine_tune_metrics
    )
    akida_model.fit(
        train_dataset,
        epochs=qat_epochs,
        verbose=fit_verbose,
        validation_data=validation_dataset,
        callbacks=qat_callbacks,
    )
    print("Running quantization-aware training OK")
    print("")
    return akida_model
# Train the float FOMO model, then quantize it and fine-tune with QAT.
model, akida_model = train(
    # Data and input geometry
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    input_shape=MODEL_INPUT_SHAPE,
    num_classes=classes,
    batch_size=BATCH_SIZE,
    # Architecture and loss weighting
    alpha=0.5,
    object_weight=100,
    # Optimisation schedule
    learning_rate=LEARNING_RATE,
    num_epochs=EPOCHS,
    lr_finder=False,
    # Checkpointing, hooks and quantization steps
    best_model_path=BEST_MODEL_PATH,
    callbacks=callbacks,
    quantize_function=akida_quantize_model,
    qat_function=akida_perform_qat,
    ensure_determinism=ensure_determinism,
)
# Set to False, i.e. per-channel quantization is not disabled. This flag is
# not read anywhere in this chunk; presumably the surrounding export/
# quantization pipeline consumes it - confirm.
disable_per_channel_quantization = False
Input layer (150,528 features)
Akida FOMO (Faster Objects, More Objects) AkidaNet (alpha=0.5 @224x224x3)
Output layer (1 class)
Model
Model version: