import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import (
    Dense,
    Input,
    MultiHeadAttention,
    Reshape,
    Dropout,
    GlobalAveragePooling1D,
    Conv1D,
    LayerNormalization,
    Normalization,
)
from tensorflow.keras.optimizers import Adam

EPOCHS = 10
LEARNING_RATE = 0.0005
BATCH_SIZE = 32


# model architecture
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    """Apply one transformer encoder block to ``inputs``.

    The block is layer-norm -> self-attention -> dropout -> residual add,
    followed by layer-norm -> two kernel-size-1 convolutions (with dropout
    in between) -> residual add.

    Args:
        inputs: Keras tensor to transform; its last dimension sets the
            number of filters of the final convolution.
        head_size: ``key_dim`` passed to ``MultiHeadAttention``.
        num_heads: Number of attention heads.
        ff_dim: Number of filters of the first (ReLU) convolution.
        dropout: Dropout rate used inside the attention layer and after
            both the attention and the first convolution.

    Returns:
        The transformed Keras tensor (sum of the feed-forward output and
        the attention residual).
    """
    # Normalization and Attention
    x = LayerNormalization(epsilon=1e-6)(inputs)
    x = MultiHeadAttention(
        key_dim=head_size, num_heads=num_heads, dropout=dropout
    )(x, x)
    x = Dropout(dropout)(x)
    res = x + inputs

    # Feed Forward Part
    x = LayerNormalization(epsilon=1e-6)(res)
    x = Conv1D(filters=ff_dim, kernel_size=1, activation="relu")(x)
    x = Dropout(dropout)(x)
    # Project back to the input's channel count so the residual add works.
    x = Conv1D(filters=inputs.shape[-1], kernel_size=1)(x)
    return x + res


def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    """Build the transformer classifier.

    The flat input is reshaped to ``(steps, channels)``, normalized per
    channel with pre-calculated statistics, passed through
    ``num_transformer_blocks`` encoder blocks, pooled, and fed to an MLP
    head ending in a softmax layer.

    Args:
        input_shape: Shape of one sample (without the batch dimension).
            The total number of elements must be divisible by the number
            of channels (the length of the normalization statistics).
        head_size: ``key_dim`` of each attention layer.
        num_heads: Number of attention heads per encoder block.
        ff_dim: Hidden filter count of each encoder's feed-forward part.
        num_transformer_blocks: How many encoder blocks to stack.
        mlp_units: Iterable of layer widths for the dense head.
        dropout: Dropout rate used inside the encoder blocks.
        mlp_dropout: Dropout rate applied after each dense head layer.

    Returns:
        An uncompiled ``tensorflow.keras.Model``.

    Raises:
        ValueError: If the input size is not divisible by the number of
            channels.

    Note:
        The width of the output layer comes from the module-level name
        ``classes``, which is not defined in this chunk and is expected to
        be supplied by the surrounding environment.
    """
    # pre-calculated mean and variance
    channel_mean = [-0.047443, -6.846333, -1.057524]
    channel_variance = [16.179484, 33.019396, 22.892909]
    num_channels = len(channel_mean)

    # Size the reshape from the argument rather than from a global so that
    # it always agrees with the Input layer created below.
    flat_length = 1
    for axis_size in input_shape:
        flat_length *= int(axis_size)
    if flat_length % num_channels:
        raise ValueError(
            f"input size {flat_length} is not divisible by "
            f"{num_channels} channels"
        )

    inputs = Input(shape=input_shape)
    x = Reshape([flat_length // num_channels, num_channels])(inputs)
    x = Normalization(
        axis=-1, mean=channel_mean, variance=channel_variance
    )(x)

    for _ in range(num_transformer_blocks):
        x = transformer_encoder(x, head_size, num_heads, ff_dim, dropout)

    # NOTE(review): with data_format="channels_first" Keras treats axis 1
    # as the feature axis, so this averages over the last (channel) axis
    # and yields one value per time step. Confirm that this is intended
    # rather than "channels_last"; behaviour is deliberately left as-is.
    x = GlobalAveragePooling1D(data_format="channels_first")(x)
    for units in mlp_units:
        x = Dense(units, activation="relu")(x)
        x = Dropout(mlp_dropout)(x)
    outputs = Dense(classes, activation="softmax")(x)
    return Model(inputs, outputs)


# `input_length` is not defined in this chunk; it is expected to be
# provided by the surrounding environment.
input_shape = (input_length, )

model = build_model(
    input_shape,
    head_size=64,
    num_heads=2,
    ff_dim=4,
    num_transformer_blocks=1,
    mlp_units=[32],
    mlp_dropout=0.40,
    dropout=0.25,
)

# this controls the learning rate
opt = Adam(learning_rate=LEARNING_RATE, beta_1=0.9, beta_2=0.999)
# `callbacks`, `BatchLoggerCallback` and `train_sample_count` are likewise
# expected to come from the surrounding environment.
callbacks.append(BatchLoggerCallback(BATCH_SIZE, train_sample_count, epochs=EPOCHS))

# train the neural network
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])
# Print the layer-by-layer overview of the compiled model.
model.summary()

# Group both datasets into batches of BATCH_SIZE samples; the final,
# possibly smaller, batch is kept rather than dropped.
train_dataset = train_dataset.batch(BATCH_SIZE, drop_remainder=False)
validation_dataset = validation_dataset.batch(
    BATCH_SIZE,
    drop_remainder=False,
)

# Run training for EPOCHS passes, evaluating on the validation set after
# each epoch and reporting through the registered callbacks.
model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=EPOCHS,
    callbacks=callbacks,
    verbose=2,
)

# Nothing in this chunk reads this flag; presumably it is consumed by the
# surrounding tooling after training -- confirm before changing it.
disable_per_channel_quantization = False
