Training settings
Neural network architecture
import sys
sys.path.append('./resources/libraries')
import os
import tensorflow as tf
import numpy as np
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import BatchNormalization, Conv2D, Softmax
from tensorflow.keras.models import Model
from ei_tensorflow.constrained_object_detection import dataset, metrics, util
from ei_tensorflow.velo import train_keras_model_with_velo
from ei_shared.pretrained_weights import get_or_download_pretrained_weights
import ei_tensorflow.training
WEIGHTS_PREFIX = os.environ.get('WEIGHTS_PREFIX', os.getcwd())
def build_model(input_shape: tuple, weights: str, alpha: float,
                num_classes_with_background: int) -> tf.keras.Model:
    """Construct a constrained object detection model.

    Args:
        input_shape: Passed to MobileNetV2 construction.
        weights: Weights for initialization of MobileNetV2 where None implies
            random initialization.
        alpha: MobileNetV2 alpha value.
        num_classes_with_background: Total number of classes including background.

    Returns:
        Uncompiled Keras model.

        Model takes (B, H, W, C) input and
        returns (B, H//8, W//8, num_classes_with_background) logits.
    """
    # Create full MobileNetV2 from (H, W, C) input to (H/8, W/8, C) output
    mobile_net_v2 = MobileNetV2(input_shape=input_shape,
                                weights=weights,
                                alpha=alpha,
                                include_top=True)
    # Speed up batch normalization layers
    for layer in mobile_net_v2.layers:
        if isinstance(layer, BatchNormalization):
            layer.momentum = 0.9
    # Cut MobileNetV2 at 1/8th input resolution
    cut_point = mobile_net_v2.get_layer('block_6_expand_relu')
    # Attach a small additional head on MobileNetV2
    head = Conv2D(filters=32, kernel_size=1, strides=1,
                  activation='relu', name='head')(cut_point.output)
    logits = Conv2D(filters=num_classes_with_background, kernel_size=1, strides=1,
                    activation=None, name='logits')(head)
    return Model(inputs=mobile_net_v2.input, outputs=logits)
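# Illustrative use of build_model (a sketch, commented out so the training run
# is unchanged; assumes a 96x96 RGB input, random initialization, and a single
# foreground class):
#
#   demo = build_model(input_shape=(96, 96, 3), weights=None, alpha=0.35,
#                      num_classes_with_background=2)
#   print(demo.output_shape)  # (None, 12, 12, 2): a 96//8 x 96//8 logit grid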
def construct_weighted_xent_fn_per_pixel(model_output_shape, class_weights):
    """Construct a custom loss function with per-pixel class weights.

    Args:
        model_output_shape: Output shape of the model.
        class_weights: List of weights per class (including background class).

    Returns:
        Loss function suitable for use with Keras model.
    """
    if len(model_output_shape) != 4:
        raise Exception("Expected model_output_shape of form (BATCH_SIZE, H, W, NUM_CLASSES)")
    _batch_size, height, width, num_classes_model = model_output_shape
    if num_classes_model != len(class_weights):
        raise Exception(f"Number of class weights ({len(class_weights)}) does not match "
                        f"number of classes ({num_classes_model})")
    class_weights_tensor = tf.constant(class_weights, dtype=tf.float32)

    def weighted_xent(y_true, y_pred_logits):
        # Convert y_true from one-hot encoding to class indices
        class_indices = tf.argmax(y_true, axis=-1)  # Shape: (batch_size, height, width)
        # Look up the weight of each pixel's true class
        pixel_weights = tf.gather(class_weights_tensor, class_indices)  # Shape: (batch_size, height, width)
        # Compute the per-pixel loss using sparse_softmax_cross_entropy_with_logits
        losses = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=class_indices, logits=y_pred_logits)
        # Multiply the per-pixel loss by the per-pixel weights
        weighted_losses = losses * pixel_weights  # Shape: (batch_size, height, width)
        # Compute mean loss over all pixels and batches
        return tf.reduce_mean(weighted_losses)

    return weighted_xent
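# Toy check of the weighted loss (a sketch, commented out so training is
# unchanged; the shapes and weights below are illustrative only):
#
#   loss_fn = construct_weighted_xent_fn_per_pixel((1, 2, 2, 3), [1.0, 5.0, 5.0])
#   y_true = tf.one_hot(tf.zeros((1, 2, 2), dtype=tf.int32), depth=3)  # all background
#   logits = tf.zeros((1, 2, 2, 3))  # uniform predictions
#   print(float(loss_fn(y_true, logits)))  # ln(3) ~= 1.0986; background weight 1.0 leaves it unscaled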
def train(num_classes: int, learning_rate: float, num_epochs: int,
          alpha: float,
          train_dataset: tf.data.Dataset,
          validation_dataset: tf.data.Dataset,
          best_model_path: str,
          input_shape: tuple,
          batch_size: int,
          class_weights,
          use_velo: bool = False,
          ensure_determinism: bool = False) -> tf.keras.Model:
    """Construct and train a constrained object detection model with per-pixel class weighting.

    Args:
        num_classes: Number of classes excluding the background.
        learning_rate: Learning rate for Adam optimizer.
        num_epochs: Number of epochs for training.
        alpha: Alpha value for MobileNetV2.
        train_dataset: Training dataset.
        validation_dataset: Validation dataset.
        best_model_path: Path to save the best model.
        input_shape: Shape of the model's input.
        batch_size: Training batch size.
        class_weights: List or array of per-class weights (including background class).
        use_velo: Whether to use VeLO optimizer.
        ensure_determinism: If true, disables non-deterministic functions.

    Returns:
        Trained Keras model.
    """
    # Initialize callbacks if not already defined by the harness
    global callbacks
    callbacks = callbacks if 'callbacks' in globals() else []

    num_classes_with_background = num_classes + 1

    height, width, input_num_channels = input_shape
    if width != height:
        raise Exception(f"Only square inputs are supported; not {input_shape}")

    # Use pretrained weights if available for this channel count and alpha
    allowed_combinations = [{'num_channels': 1, 'alpha': 0.1},
                            {'num_channels': 1, 'alpha': 0.35},
                            {'num_channels': 3, 'alpha': 0.1},
                            {'num_channels': 3, 'alpha': 0.35}]
    weights = get_or_download_pretrained_weights(
        WEIGHTS_PREFIX, input_num_channels, alpha, allowed_combinations)

    model = build_model(
        input_shape=input_shape,
        weights=weights,
        alpha=alpha,
        num_classes_with_background=num_classes_with_background
    )

    # Derive output size from the model; shape is (batch, height, width, classes)
    model_output_shape = model.layers[-1].output.shape
    _batch, output_height, output_width, num_classes_model = model_output_shape
    if output_width != output_height:
        raise Exception(f"Only square outputs are supported; not {model_output_shape}")

    # Build the custom per-pixel weighted loss function
    weighted_xent = construct_weighted_xent_fn_per_pixel(model_output_shape, class_weights)

    prefetch_policy = 1 if ensure_determinism else tf.data.experimental.AUTOTUNE

    # Transform bounding box labels into segmentation maps
    def as_segmentation(ds, shuffle):
        ds = ds.map(dataset.bbox_to_segmentation(
            output_width, num_classes_with_background))
        if not ensure_determinism and shuffle:
            ds = ds.shuffle(buffer_size=batch_size * 4)
        ds = ds.batch(batch_size, drop_remainder=False).prefetch(prefetch_policy)
        return ds

    train_segmentation_dataset = as_segmentation(train_dataset, True)
    validation_segmentation_dataset = as_segmentation(validation_dataset, False)

    validation_dataset_for_callback = (validation_dataset
                                       .batch(batch_size, drop_remainder=False)
                                       .prefetch(prefetch_policy))

    # Initialize biases of final classifier based on training data prior
    util.set_classifier_biases_from_dataset(
        model, train_segmentation_dataset)

    if not use_velo:
        model.compile(loss=weighted_xent,
                      optimizer=Adam(learning_rate=learning_rate))

    # Create callbacks
    callbacks.append(metrics.CentroidScoring(validation_dataset_for_callback,
                                             output_width, num_classes_with_background))
    callbacks.append(metrics.PrintPercentageTrained(num_epochs))

    # Model checkpointing based on the best validation F1 score
    callbacks.append(
        tf.keras.callbacks.ModelCheckpoint(best_model_path,
                                           monitor='val_f1', save_best_only=True, mode='max',
                                           save_weights_only=True, verbose=0))

    if use_velo:
        from tensorflow.python.framework.errors_impl import ResourceExhaustedError
        try:
            train_keras_model_with_velo(
                model,
                train_segmentation_dataset,
                validation_segmentation_dataset,
                loss_fn=weighted_xent,
                num_epochs=num_epochs,
                callbacks=callbacks
            )
        except ResourceExhaustedError as e:
            print(str(e))
            raise Exception(
                "ResourceExhaustedError caught during train_keras_model_with_velo."
                " Though VeLO encourages a large batch size, the current"
                f" size of {batch_size} may be too large. Please try a lower"
                " value. For further assistance please contact support"
                " at https://forum.edgeimpulse.com/")
    else:
        model.fit(train_segmentation_dataset,
                  validation_data=validation_segmentation_dataset,
                  epochs=num_epochs, callbacks=callbacks, verbose=0)

    # Restore best weights
    model.load_weights(best_model_path)

    # Add explicit softmax layer before export (for inference)
    softmax_layer = Softmax()(model.layers[-1].output)
    model = Model(model.input, softmax_layer)

    return model
# Training parameters. `args`, along with train_dataset, validation_dataset,
# MODEL_INPUT_SHAPE, BEST_MODEL_PATH and ensure_determinism used below, are
# provided by the surrounding Edge Impulse training harness.
EPOCHS = args.epochs or 60
LEARNING_RATE = args.learning_rate or 0.001
BATCH_SIZE = args.batch_size or 32
alpha = 0.35  # MobileNetV2 width multiplier; adjust as needed
def get_num_classes_from_dataset(ds):
    """Infer the number of classes from the width of the one-hot class vectors.

    The parameter is named `ds` to avoid shadowing the imported `dataset` module.
    """
    num_classes = 0
    for sample in ds:
        labels = sample[1]  # labels is a tuple of two RaggedTensors
        class_vectors = labels[1]  # RaggedTensor of class vectors (one-hot)
        # No labelled objects in this sample; skip it
        if tf.shape(class_vectors)[0] == 0:
            continue
        class_vectors_tensor = class_vectors.to_tensor()
        num_classes_in_sample = class_vectors_tensor.shape[1]
        if num_classes_in_sample > num_classes:
            num_classes = num_classes_in_sample
    if num_classes == 0:
        raise Exception("Could not determine number of classes from dataset.")
    return num_classes
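# Hand-rolled illustration of the expected label structure (a sketch,
# commented out; assumes the (bounding_boxes, one_hot_classes) RaggedTensor
# label format described above):
#
#   toy = tf.data.Dataset.from_tensors(
#       (tf.zeros((96, 96, 3)),
#        (tf.ragged.constant([[0.1, 0.1, 0.3, 0.3]]),
#         tf.ragged.constant([[0.0, 1.0, 0.0]]))))
#   print(get_num_classes_from_dataset(toy))  # 3 (width of the one-hot vectors)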
# Get num_classes from the training dataset
num_classes = get_num_classes_from_dataset(train_dataset)
num_classes_with_background = num_classes + 1
# Use the same allowed combinations as before
allowed_combinations = [{'num_channels': 1, 'alpha': 0.1},
                        {'num_channels': 1, 'alpha': 0.35},
                        {'num_channels': 3, 'alpha': 0.1},
                        {'num_channels': 3, 'alpha': 0.35}]
weights = get_or_download_pretrained_weights(
    WEIGHTS_PREFIX, MODEL_INPUT_SHAPE[2], alpha, allowed_combinations)
# Build the model to get output_width and output_height
model = build_model(
    input_shape=MODEL_INPUT_SHAPE,
    weights=weights,
    alpha=alpha,
    num_classes_with_background=num_classes_with_background
)
# Derive output size from model
model_output_shape = model.layers[-1].output.shape
_batch, output_height, output_width, num_classes_model = model_output_shape
if output_width != output_height:
    raise Exception(f"Only square outputs are supported; not {model_output_shape}")
prefetch_policy = 1 if ensure_determinism else tf.data.experimental.AUTOTUNE
# Transform bounding box labels into segmentation maps
def as_segmentation(ds, shuffle):
    ds = ds.map(dataset.bbox_to_segmentation(
        output_width, num_classes_with_background))
    if not ensure_determinism and shuffle:
        ds = ds.shuffle(buffer_size=BATCH_SIZE * 4)
    ds = ds.batch(BATCH_SIZE, drop_remainder=False).prefetch(prefetch_policy)
    return ds
train_segmentation_dataset = as_segmentation(train_dataset, True)
validation_segmentation_dataset = as_segmentation(validation_dataset, False)
# Function to compute class frequencies from the segmentation dataset
def compute_class_pixel_frequencies_from_segmentation_dataset(
        segmentation_dataset, num_classes_with_background):
    class_counts = np.zeros(num_classes_with_background, dtype=np.float64)
    total_pixels = 0
    for _images, y_trues in segmentation_dataset:
        # y_trues shape: (batch_size, H, W, num_classes_with_background)
        # Sum over batch and spatial dimensions
        counts = tf.reduce_sum(y_trues, axis=[0, 1, 2]).numpy()  # Shape: (num_classes_with_background,)
        class_counts += counts
        total_pixels += np.prod(y_trues.shape[:3])  # batch_size * H * W
    return class_counts, total_pixels
# Compute class frequencies from the training segmentation dataset
class_counts, total_pixels = compute_class_pixel_frequencies_from_segmentation_dataset(
    train_segmentation_dataset, num_classes_with_background)
# Compute initial class weights inversely proportional to class frequencies
initial_class_weights = [
    (total_pixels / (num_classes_with_background * count)) if count > 0 else 0.0
    for count in class_counts
]
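# Worked example of the inverse-frequency formula above (illustrative numbers
# only): with 3 classes including background and 100,000 total pixels, a
# foreground class covering 2,000 pixels gets weight
# 100000 / (3 * 2000) ~= 16.7, while a background class covering 95,000
# pixels gets 100000 / (3 * 95000) ~= 0.35.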
# The background class dominates the pixel counts. Rather than pinning its
# weight to 1.0, its weight is derived from the same inverse-frequency
# formula as the foreground classes, so no further adjustment is applied.
class_weights = initial_class_weights  # Use the computed weights directly
print(f"Class counts (pixels): {class_counts}")
print(f"Total pixels: {total_pixels}")
print(f"Class weights: {class_weights}")
# Proceed with training
model = train(
    num_classes=num_classes,  # Excluding the background class
    learning_rate=LEARNING_RATE,
    num_epochs=EPOCHS,
    alpha=alpha,
    train_dataset=train_dataset,
    validation_dataset=validation_dataset,
    best_model_path=BEST_MODEL_PATH,
    input_shape=MODEL_INPUT_SHAPE,
    batch_size=BATCH_SIZE,
    class_weights=class_weights,
    use_velo=False,
    ensure_determinism=ensure_determinism
)
# When True, per-channel quantization is disabled during int8 model
# conversion; leave False to keep the default per-channel behaviour.
disable_per_channel_quantization = False
Model
Input layer (49,152 features)
FOMO (Faster Objects, More Objects) MobileNetV2 0.35
Output layer (4 classes)