# Fine-Tuning the Audio Spectrogram Transformer (AST) for Audio Classification

This Jupyter Notebook is a step-by-step guide to fine-tuning the Audio Spectrogram Transformer (AST) model on your own audio classification dataset, using tools from the Hugging Face ecosystem and PyTorch. It covers the entire workflow: data loading, preprocessing, audio augmentation, model configuration, and training. As a running example, it fine-tunes the model on the ESC-50 environmental sound dataset from the Hugging Face Hub.

**Published:** 30.07.2024  
**Author:** Marius Steger  
**Email:** [marius.steger@renumics.com](mailto:marius.steger@renumics.com)  
**Organization:** [Renumics](https://renumics.com/)  

## Step 1: Install Required Packages
Before we start, install all the required packages.

In [1]:
!pip install transformers[torch] datasets[audio] audiomentations


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1[0m[39;49m -> [0m[32;49m24.1.2[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m


## Step 2: Load Your Data in the Correct Format

In [2]:
from datasets import Dataset, Audio, ClassLabel, Features, load_dataset

In [3]:
# Define class labels
#class_labels = ClassLabel(names=["bang", "dog_bark"])

# Define features with audio and label columns
#features = Features({
#    "audio": Audio(),
#    "labels": class_labels
#})

# Load data (example with a dictionary)
#dataset = Dataset.from_dict({
#    "audio": ["/audio/fold1/7061-6-0-0.wav", "/audio/fold1/7383-3-0-0.wav"],
#    "labels": [0, 1],
#}, features=features)

In [4]:
# Load a pre-existing dataset from the Hugging Face Hub; only its "train" split is fetched
# and used as the source for both our training and evaluation data.
esc50 = load_dataset(path="ashraq/esc50", split="train")

Repo card metadata block was not found. Setting CardData to empty.


## Step 3: Preprocess the Audio Data

In [5]:
import numpy as np
from datasets import Audio, ClassLabel
from transformers import ASTFeatureExtractor

In [6]:
# get target value -> class name mappings
# np.unique returns the sorted target ids and the row index of each id's first occurrence,
# so class_names[i] is the category name belonging to target id i.
df = esc50.select_columns(["target", "category"]).to_pandas()
target_ids, first_rows = np.unique(df["target"], return_index=True)
# ClassLabel(names=...) assigns id i to names[i]; this only agrees with the original targets
# when they are exactly 0..N-1, so fail loudly instead of silently mislabeling.
assert np.array_equal(target_ids, np.arange(len(target_ids))), "target ids must be contiguous and start at 0"
class_names = df["category"].iloc[first_rows].to_list()

# cast target and audio column
esc50 = esc50.cast_column("target", ClassLabel(names=class_names))
esc50 = esc50.cast_column("audio", Audio(sampling_rate=16000))

# rename the target feature
esc50 = esc50.rename_column("target", "labels")
# number of classes declared by the ClassLabel feature (== len(class_names)); avoids
# materializing and de-duplicating the whole label column again
num_labels = esc50.features["labels"].num_classes

In [7]:
# Define the pretrained checkpoint (Hugging Face Hub id) and load its feature extractor
pretrained_model = "MIT/ast-finetuned-audioset-10-10-0.4593"
feature_extractor = ASTFeatureExtractor.from_pretrained(pretrained_model)

# Values the preprocessing functions need from the extractor:
# the sampling rate it expects, and the name of its first declared model input.
SAMPLING_RATE = feature_extractor.sampling_rate
model_input_name = feature_extractor.model_input_names[0]

In [8]:
# Preprocessing function (no augmentation)
def preprocess_audio(batch):
    """Turn a batch of decoded audio into AST model inputs.

    Args:
        batch: a batch dict as passed by ``Dataset.set_transform``; ``batch["input_values"]``
            holds decoded audio dicts with an ``"array"`` waveform, ``batch["labels"]`` the class ids.

    Returns:
        dict with the extracted features under ``model_input_name`` (PyTorch tensors)
        and the labels as a list.
    """
    waveforms = [clip["array"] for clip in batch["input_values"]]
    extracted = feature_extractor(waveforms, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    return {
        model_input_name: extracted.get(model_input_name),
        "labels": list(batch["labels"]),
    }

In [9]:
# we use the esc50 train split for this tutorial on how to fine-tune the AST Model
dataset = esc50
# mapping from class-name STRINGs to INT ids, built from the public `ClassLabel.names` list
# (the position of a name in `names` is its class id) instead of the private `_str2int` attribute
label2id = {name: class_id for class_id, name in enumerate(dataset.features["labels"].names)}

In [10]:
# split training data
# `"test" in dataset` is only a split lookup for a DatasetDict (a dict of named splits). A plain
# Dataset has no split names, and `in` would fall back to iterating over (and decoding) every row,
# so check the type first; the `or` short-circuits before that scan can happen.
if isinstance(dataset, Dataset) or "test" not in dataset:
    # a DatasetDict without a held-out split is re-split from its "train" split
    split_source = dataset if isinstance(dataset, Dataset) else dataset["train"]
    dataset = split_source.train_test_split(
        test_size=0.2, shuffle=True, seed=0, stratify_by_column="labels")

## Step 4: Add Audio Augmentations

In [11]:
import torch
from audiomentations import Compose, AddGaussianSNR, GainTransition, Gain, ClippingDistortion, TimeStretch, PitchShift

In [12]:
# Define audio augmentations (applied to training clips only, via the transform below).
# Compose-level p=0.8 and shuffle=True control how often the chain runs and whether its order is randomized.
audio_augmentations = Compose(
    [
        AddGaussianSNR(min_snr_db=10, max_snr_db=20),  # additive noise at 10-20 dB SNR
        Gain(min_gain_db=-6, max_gain_db=6),  # global volume change of up to +/-6 dB
        GainTransition(  # gradual volume change over part of the clip
            min_gain_db=-6,
            max_gain_db=6,
            min_duration=0.01,
            max_duration=0.3,
            duration_unit="fraction",
        ),
        ClippingDistortion(min_percentile_threshold=0, max_percentile_threshold=30, p=0.5),
        TimeStretch(min_rate=0.8, max_rate=1.2),  # speed change without pitch change
        PitchShift(min_semitones=-4, max_semitones=4),  # up to +/-4 semitones
    ],
    p=0.8,
    shuffle=True,
)

In [13]:
# Preprocessing with augmentations (training split)
def preprocess_audio_with_transforms(batch):
    """Augment each waveform in the batch, then extract AST model inputs.

    Args:
        batch: a batch dict as passed by ``Dataset.set_transform``; ``batch["input_values"]``
            holds decoded audio dicts with an ``"array"`` waveform, ``batch["labels"]`` the class ids.

    Returns:
        dict with the extracted features under ``model_input_name`` (PyTorch tensors)
        and the labels as a list.
    """
    augmented = []
    for clip in batch["input_values"]:
        # random augmentation is re-drawn every time the example is accessed
        augmented.append(audio_augmentations(clip["array"], sample_rate=SAMPLING_RATE))
    extracted = feature_extractor(augmented, sampling_rate=SAMPLING_RATE, return_tensors="pt")
    return {model_input_name: extracted.get(model_input_name), "labels": list(batch["labels"])}

In [14]:
# Resample audio to the feature extractor's rate and expose it under the column name
# the preprocessing functions read ("input_values").
dataset = (
    dataset
    .cast_column("audio", Audio(sampling_rate=SAMPLING_RATE))
    .rename_column("audio", "input_values")
)

In [15]:
# calculate values for normalization
# Normalization is switched off while collecting statistics so the raw features are measured;
# try/finally guarantees it is switched back on even if feature extraction fails part-way.
feature_extractor.do_normalize = False
mean = []
std = []
try:
    # we use the transformation w/o augmentation on the training dataset to calculate the mean + std
    dataset["train"].set_transform(preprocess_audio, output_all_columns=False)
    # Each iterated example already holds the extracted features; reuse them instead of
    # re-indexing dataset["train"][i], which would run feature extraction a second time per clip.
    for example in dataset["train"]:
        clip_features = example[model_input_name]
        mean.append(torch.mean(clip_features))
        std.append(torch.std(clip_features))
finally:
    feature_extractor.do_normalize = True

# Dataset statistics = average of the per-clip means and of the per-clip standard deviations.
# Stored as plain Python floats: NumPy float32 scalars are not JSON-serializable, which would
# break serializing the feature extractor config (e.g. feature_extractor.save_pretrained).
feature_extractor.mean = float(np.mean(mean))
feature_extractor.std = float(np.mean(std))

print("Calculated mean and std:", feature_extractor.mean, feature_extractor.std)

Calculated mean and std: -3.3504603 4.387065


In [16]:
# Apply transforms: augmented preprocessing for training, plain preprocessing for evaluation.
# set_transform runs lazily, on every access, so augmentations differ each epoch.
split_transforms = {
    "train": preprocess_audio_with_transforms,
    "test": preprocess_audio,
}
for split_name, transform in split_transforms.items():
    dataset[split_name].set_transform(transform, output_all_columns=False)

## Step 5: Configure and Initialize the AST for Fine-Tuning

In [17]:
import evaluate
from transformers import ASTConfig, ASTForAudioClassification, TrainingArguments, Trainer

In [18]:
# Load the pretrained configuration and adapt its classification head to our label set.
config = ASTConfig.from_pretrained(pretrained_model)
# num_labels goes first: in transformers, setting it can regenerate default label mappings,
# so the explicit mappings are assigned afterwards.
config.num_labels = num_labels
config.label2id = label2id
config.id2label = {class_id: class_name for class_name, class_id in label2id.items()}

In [19]:
# Initialize the model with the updated configuration.
# ignore_mismatched_sizes=True drops the checkpoint's 527-way classifier head and creates a freshly
# initialized head of size config.num_labels (see the "newly initialized" log below); all other
# weights are loaded from the checkpoint. No extra model.init_weights() call: from_pretrained already
# initializes the new head, and re-running initialization is at best a no-op and, depending on the
# transformers version, may overwrite the pretrained weights.
model = ASTForAudioClassification.from_pretrained(pretrained_model, config=config, ignore_mismatched_sizes=True)

Some weights of ASTForAudioClassification were not initialized from the model checkpoint at MIT/ast-finetuned-audioset-10-10-0.4593 and are newly initialized because the shapes did not match:
- classifier.dense.bias: found shape torch.Size([527]) in the checkpoint and torch.Size([50]) in the model instantiated
- classifier.dense.weight: found shape torch.Size([527, 768]) in the checkpoint and torch.Size([50, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


### Set Up Metrics and Start Training

In [20]:
# Configure training arguments
training_args = TrainingArguments(
    output_dir="./runs/ast_classifier",
    logging_dir="./logs/ast_classifier",
    report_to="tensorboard",
    learning_rate=5e-5,  # LEARNING RATE
    push_to_hub=False,
    num_train_epochs=10,  # EPOCHS
    per_device_train_batch_size=8,  # BATCH SIZE
    # Evaluate and checkpoint once per epoch. eval_steps/save_steps only take effect with the
    # "steps" strategy, so they are not set here (they would be silently ignored).
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # the "eval_" prefixed metric ("eval_accuracy") is utilized
    logging_strategy="steps",
    logging_steps=20,
)

In [21]:
# Define evaluation metrics: load each `evaluate` metric once and reuse it for every evaluation.
accuracy, recall, precision, f1 = (
    evaluate.load(metric_name) for metric_name in ("accuracy", "recall", "precision", "f1")
)

# Two-class problems use "binary" averaging; otherwise per-class scores are averaged equally ("macro").
AVERAGE = "binary" if config.num_labels <= 2 else "macro"

# setup metrics function
def compute_metrics(eval_pred):
    """Compute accuracy plus AVERAGE-averaged precision, recall and F1 for the Trainer.

    Args:
        eval_pred: the Trainer's evaluation output; ``predictions`` holds the logits
            (class scores along axis 1) and ``label_ids`` the reference labels.

    Returns:
        dict mapping metric name -> value, merged from the individual `evaluate` metrics.
    """
    references = eval_pred.label_ids
    # predicted class = index of the highest logit for each example
    predicted = np.argmax(eval_pred.predictions, axis=1)

    results = dict(accuracy.compute(predictions=predicted, references=references))
    for averaged_metric in (precision, recall, f1):
        results.update(averaged_metric.compute(predictions=predicted, references=references, average=AVERAGE))
    return results

In [22]:
# setup trainer
trainer = Trainer(
    model=model,
    args=training_args,  # the training arguments configured above
    compute_metrics=compute_metrics,  # we use the metrics function from above
    train_dataset=dataset["train"],  # uses the augmenting transform
    eval_dataset=dataset["test"],  # uses the non-augmenting transform
)

In [23]:
# start training; the returned TrainOutput is kept and displayed as the cell's result
train_result = trainer.train()
train_result



Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,1.0649,0.557234,0.8525,0.888413,0.8525,0.844813
2,0.4705,0.292247,0.9025,0.918432,0.9025,0.902523
3,0.4024,0.295535,0.92,0.935807,0.92,0.921613
4,0.4021,0.304958,0.9275,0.938123,0.9275,0.926955
5,0.2296,0.268698,0.915,0.92573,0.915,0.911636
6,0.2256,0.191082,0.9475,0.953143,0.9475,0.947632
7,0.1784,0.268304,0.93,0.938053,0.93,0.925158
8,0.0802,0.175834,0.95,0.956268,0.95,0.950184
9,0.082,0.141045,0.9575,0.962,0.9575,0.95741
10,0.0175,0.133586,0.955,0.9595,0.955,0.9551


Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}
Non-default generation parameters: {'max_length': 1024}


TrainOutput(global_step=2000, training_loss=0.4032806022465229, metrics={'train_runtime': 734.9587, 'train_samples_per_second': 21.77, 'train_steps_per_second': 2.721, 'total_flos': 1.084989898752e+18, 'train_loss': 0.4032806022465229, 'epoch': 10.0})