Cristian Axenie / Project SPIDER - Team NeurOhm Brainchip Akida Public

Training settings

Please provide a valid number of training cycles (numeric only)
Please provide a valid number for the learning rate (between 0 and 1)
Please provide a valid training processor option

Augmentation settings

Advanced training settings

Neural network architecture

import os
import sys  #! bug fix: sys.exit() is called below but `sys` was never imported

import tensorflow as tf
import numpy as np
from akida_models import akidanet_imagenet
from keras import Model
from tensorflow.keras.optimizers.legacy import Adam
from tensorflow.keras.layers import BatchNormalization, Conv2D, Softmax, ReLU
from cnn2snn import check_model_compatibility
from ei_tensorflow.constrained_object_detection import models, dataset, metrics, util

#! Directory holding the pretrained transfer-learning weights; defaults to CWD.
WEIGHTS_PREFIX = os.environ.get('WEIGHTS_PREFIX', os.getcwd())


def build_model(input_shape: tuple, alpha: float,
                num_classes: int, weight_regularizer=None) -> tf.keras.Model:
    """Construct a constrained object detection (FOMO) model on an AkidaNet base.

    Args:
        input_shape: Passed to AkidaNet construction, e.g. (224, 224, 3).
        alpha: AkidaNet width-multiplier (alpha) value.
        num_classes: Number of classes, i.e. final dimension size, in output.
        weight_regularizer: Optional kernel regularizer for the two head convs.

    Returns:
        Uncompiled keras model. Model takes (B, H, W, C) input and
        returns (B, H//8, W//8, num_classes) logits.
    """
    #! Create a quantized base model without top layers
    a_base_model = akidanet_imagenet(input_shape=input_shape,
                                     alpha=alpha,
                                     include_top=False,
                                     input_scaling=None)

    #! Get pretrained quantized weights and load them into the base model.
    #! Available base models are:
    #!   akidanet_imagenet_224.h5          - float32 model, 224x224x3, alpha=1.00
    #!   akidanet_imagenet_224_alpha_50.h5 - float32 model, 224x224x3, alpha=0.50
    #!   akidanet_imagenet_224_alpha_25.h5 - float32 model, 224x224x3, alpha=0.25
    #!   akidanet_imagenet_160.h5          - float32 model, 160x160x3, alpha=1.00
    #!   akidanet_imagenet_160_alpha_50.h5 - float32 model, 160x160x3, alpha=0.50
    #!   akidanet_imagenet_160_alpha_25.h5 - float32 model, 160x160x3, alpha=0.25
    pretrained_weights = os.path.join(
        WEIGHTS_PREFIX,
        'transfer-learning-weights/akidanet/akidanet_imagenet_224_alpha_50.h5')
    #! by_name/skip_mismatch let weights load even if some layer shapes differ.
    a_base_model.load_weights(pretrained_weights, by_name=True, skip_mismatch=True)
    a_base_model.trainable = True

    #! Default batch norm is configured for huge networks, let's speed it up
    for layer in a_base_model.layers:
        if isinstance(layer, BatchNormalization):
            layer.momentum = 0.9

    #! Cut AkidaNet where it hits 1/8th input resolution; i.e. (HW/8, HW/8, C)
    a_cut_point = a_base_model.get_layer('separable_5_relu')

    #! Now attach a small additional head on the AkidaNet
    a_model_part_head = Conv2D(filters=32, kernel_size=1, strides=1,
                               padding='same',
                               kernel_regularizer=weight_regularizer)(a_cut_point.output)
    a_model_part = ReLU()(a_model_part_head)
    a_logits = Conv2D(filters=num_classes, kernel_size=1, strides=1,
                      padding='same', activation=None,
                      kernel_regularizer=weight_regularizer)(a_model_part)
    fomo_akida = Model(inputs=a_base_model.input, outputs=a_logits)

    #! Check if the model is compatible with Akida (fail quickly before training)
    compatible = check_model_compatibility(fomo_akida, input_is_image=True)
    if not compatible:
        print("Model is not compatible with Akida!")
        sys.exit(1)

    return fomo_akida


def train(num_classes: int, learning_rate: float, num_epochs: int,
          alpha: float, object_weight: int,
          train_dataset: tf.data.Dataset,
          validation_dataset: tf.data.Dataset,
          best_model_path: str,
          input_shape: tuple,
          callbacks: 'list',
          quantize_function,
          qat_function,
          batch_size: int,
          ensure_determinism: bool = False) -> tf.keras.Model:
    """Construct and train a constrained object detection model.

    Constructs a new constrained object detection model with
    num_classes+1 outputs (denoting the classes with an implied background
    class of 0). Both training and validation datasets are adapted from
    (x, (bbox, one_hot_y)) to (x, segmentation_map). Model is trained with
    a custom weighted cross entropy function.

    Args:
        num_classes: Number of classes in datasets. This does not include
            the implied background class introduced by segmentation map
            dataset conversion.
        learning_rate: Learning rate for Adam.
        num_epochs: Number of epochs passed to model.fit
        alpha: Alpha used to construct AkidaNet. Pretrained weights will be
            used if there is a matching set.
        object_weight: The weighting to give the object in the loss function
            where background has an implied weight of 1.0.
        train_dataset: Training dataset of (x, (bbox, one_hot_y))
        validation_dataset: Validation dataset of (x, (bbox, one_hot_y))
        best_model_path: Location to save best model path. Note: weights
            will be restored from this path based on best val_f1 score.
        input_shape: The shape of the model's input
        callbacks: List of callbacks
        quantize_function: Akida quantize function
        qat_function: Akida quantize-aware training function
        batch_size: Training batch size
        ensure_determinism: If true, functions that may be non-deterministic
            are disabled (e.g. autotuning prefetch). This should be true in
            test environments.

    Returns:
        Tuple of (trained keras model, quantized Akida model).
    """
    num_classes_with_background = num_classes + 1

    width, height, input_num_channels = input_shape
    if width != height:
        raise Exception(f"Only square inputs are supported; not {input_shape}")
    input_width_height = width

    model = build_model(input_shape=input_shape,
                        alpha=alpha,
                        num_classes=num_classes_with_background,
                        weight_regularizer=tf.keras.regularizers.l2(4e-5))

    #! Derive output size from model (don't shadow the num_classes parameter)
    model_output_shape = model.layers[-1].output.shape
    _batch, width, height, _num_output_classes = model_output_shape
    if width != height:
        raise Exception(f"Only square outputs are supported; not {model_output_shape}")
    output_width_height = width

    #! Build weighted cross entropy loss specific to this model size
    weighted_xent = models.construct_weighted_xent_fn(model.output.shape,
                                                      object_weight)

    prefetch_policy = 1 if ensure_determinism else tf.data.experimental.AUTOTUNE

    #! Transform bounding box labels into segmentation maps
    def as_segmentation(ds, shuffle):
        ds = ds.map(dataset.bbox_to_segmentation(output_width_height,
                                                 num_classes_with_background))
        if not ensure_determinism and shuffle:
            ds = ds.shuffle(buffer_size=batch_size * 4)
        ds = ds.batch(batch_size, drop_remainder=False).prefetch(prefetch_policy)
        return ds

    train_segmentation_dataset = as_segmentation(train_dataset, True)
    validation_segmentation_dataset = as_segmentation(validation_dataset, False)

    #! Centroid scoring runs on the raw (bbox) validation data, only batched.
    validation_dataset_for_callback = (validation_dataset
                                       .batch(batch_size, drop_remainder=False)
                                       .prefetch(prefetch_policy))

    #! Initialise bias of final classifier based on training data prior.
    util.set_classifier_biases_from_dataset(model, train_segmentation_dataset)

    opt = Adam(learning_rate=learning_rate)
    model.compile(loss=weighted_xent, optimizer=opt)

    #! Create callback that will do centroid scoring on end of epoch against
    #! validation data. Include a callback to show % progress in slow cases.
    centroid_callback = metrics.CentroidScoring(validation_dataset_for_callback,
                                                output_width_height,
                                                num_classes_with_background)
    print_callback = metrics.PrintPercentageTrained(num_epochs)

    #! Include a callback for model checkpointing based on the best validation f1.
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        best_model_path,
        monitor='val_f1',
        save_best_only=True,
        mode='max',
        save_weights_only=True,
        verbose=0)

    model.fit(train_segmentation_dataset,
              validation_data=validation_segmentation_dataset,
              epochs=num_epochs,
              callbacks=callbacks + [centroid_callback, print_callback,
                                     checkpoint_callback],
              verbose=0)

    #! Restore best weights.
    model.load_weights(best_model_path)

    #! Add explicit softmax layer before export.
    softmax_layer = Softmax()(model.layers[-1].output)
    model = Model(model.input, softmax_layer)

    #! Check if model is compatible with Akida
    compatible = check_model_compatibility(model, input_is_image=True)
    if not compatible:
        print("Model is not compatible with Akida!")
        sys.exit(1)

    #! Quantize model to 4/4/8
    akida_model = quantize_function(keras_model=model)

    #! Perform quantization-aware training
    akida_model = qat_function(akida_model=akida_model,
                               train_dataset=train_segmentation_dataset,
                               validation_dataset=validation_segmentation_dataset,
                               optimizer=opt,
                               fine_tune_loss=weighted_xent,
                               fine_tune_metrics=None,
                               callbacks=callbacks + [centroid_callback,
                                                      print_callback],
                               stopping_metric='val_f1',
                               fit_verbose=0)

    return model, akida_model


#! Training settings supplied by the surrounding harness, with fallbacks.
EPOCHS = args.epochs or 30
LEARNING_RATE = args.learning_rate or 0.001
BATCH_SIZE = args.batch_size or 32

import tensorflow as tf


def akida_quantize_model(
    keras_model,
    weight_quantization: int = 4,
    activ_quantization: int = 4,
    input_weight_quantization: int = 8,
):
    """Post-training quantize a Keras model for Akida via cnn2snn.

    Args:
        keras_model: The trained float Keras model to quantize.
        weight_quantization: Bit width for hidden-layer weights.
        activ_quantization: Bit width for activations.
        input_weight_quantization: Bit width for the first (input) layer's
            weights.

    Returns:
        The quantized (cnn2snn) Keras model.
    """
    import cnn2snn

    print("Performing post-training quantization...")
    akida_model = cnn2snn.quantize(
        keras_model,
        weight_quantization=weight_quantization,
        activ_quantization=activ_quantization,
        input_weight_quantization=input_weight_quantization,
    )
    print("Performing post-training quantization OK")
    print("")

    return akida_model


def akida_perform_qat(
    akida_model,
    train_dataset: tf.data.Dataset,
    validation_dataset: tf.data.Dataset,
    optimizer: str,
    fine_tune_loss: str,
    fine_tune_metrics: "list[str]",
    callbacks,
    stopping_metric: str = "val_accuracy",
    fit_verbose: int = 2,
    qat_epochs: int = 30,
):
    """Fine-tune a quantized model (quantization-aware training).

    Trains the quantized model with early stopping on `stopping_metric`
    (mode='max', patience=10) and restores the best weights.

    Args:
        akida_model: The quantized model to fine-tune.
        train_dataset: Batched training dataset of (x, segmentation_map).
        validation_dataset: Batched validation dataset.
        optimizer: Optimizer (or optimizer name) for compile().
        fine_tune_loss: Loss (or loss name) for compile().
        fine_tune_metrics: Metrics list for compile(); may be None.
        callbacks: Extra callbacks to run during fitting.
        stopping_metric: Metric monitored by early stopping.
        fit_verbose: Verbosity passed to fit().
        qat_epochs: Maximum number of fine-tuning epochs.

    Returns:
        The fine-tuned quantized model.
    """
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor=stopping_metric,
        mode="max",
        verbose=1,
        min_delta=0,
        patience=10,
        restore_best_weights=True,
    )
    #! Bug fix: build a new list instead of appending to the caller's list
    #! in place, which would leak the EarlyStopping callback back to the caller.
    callbacks = callbacks + [early_stopping]

    print("Running quantization-aware training...")
    akida_model.compile(
        optimizer=optimizer, loss=fine_tune_loss, metrics=fine_tune_metrics
    )
    akida_model.fit(
        train_dataset,
        epochs=qat_epochs,
        verbose=fit_verbose,
        validation_data=validation_dataset,
        callbacks=callbacks,
    )
    print("Running quantization-aware training OK")
    print("")

    return akida_model


#! Train the float model, then quantize + fine-tune it for Akida.
model, akida_model = train(num_classes=classes,
                           learning_rate=LEARNING_RATE,
                           num_epochs=EPOCHS,
                           alpha=0.5,
                           object_weight=100,
                           train_dataset=train_dataset,
                           validation_dataset=validation_dataset,
                           best_model_path=BEST_MODEL_PATH,
                           input_shape=MODEL_INPUT_SHAPE,
                           callbacks=callbacks,
                           quantize_function=akida_quantize_model,
                           qat_function=akida_perform_qat,
                           batch_size=BATCH_SIZE,
                           ensure_determinism=ensure_determinism)

disable_per_channel_quantization = False
Input layer (49,152 features)
Akida FOMO (Faster Objects, More Objects) AkidaNet (alpha=0.5 @224x224x3)
Output layer (2 classes)

Model

Model version: