This is an automated email from the ASF dual-hosted git repository.

damccorm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git

The following commit(s) were added to refs/heads/master by this push:
     new 5221fde659a Add Speech Emotion Recognition TensorFlow Notebook (#28172)

5221fde659a is described below

commit 5221fde659aa96b88dac78bd507030ef788454d2
Author: Reeba Qureshi <64488642+reeba...@users.noreply.github.com>
AuthorDate: Fri Sep 1 17:52:14 2023 +0530

    Add Speech Emotion Recognition TensorFlow Notebook (#28172)
---
 .../beam-ml/speech_emotion_tensorflow.ipynb | 2159 ++++++++++++++++++++
 1 file changed, 2159 insertions(+)

diff --git a/examples/notebooks/beam-ml/speech_emotion_tensorflow.ipynb b/examples/notebooks/beam-ml/speech_emotion_tensorflow.ipynb
new file mode 100644
index 00000000000..098cb150bfd
--- /dev/null
+++ b/examples/notebooks/beam-ml/speech_emotion_tensorflow.ipynb
@@ -0,0 +1,2159 @@
+{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kNv8XQ6-TM7W" + }, + "outputs": [], + "source": [ + "# @title ###### Licensed to the Apache Software Foundation (ASF), Version 2.0 (the \"License\")\n", + "\n", + "# Licensed to the Apache Software Foundation (ASF) under one\n", + "# or more contributor license agreements. See the NOTICE file\n", + "# distributed with this work for additional information\n", + "# regarding copyright ownership. The ASF licenses this file\n", + "# to you under the Apache License, Version 2.0 (the\n", + "# \"License\"); you may not use this file except in compliance\n", + "# with the License. You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing,\n", + "# software distributed under the License is distributed on an\n", + "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + "# KIND, either express or implied.
See the License for the\n", + "# specific language governing permissions and limitations\n", + "# under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pjsWzeiITPd6" + }, + "source": [ + "# Speech Emotion Recognition using Apache Beam\n", + "\n", + "<table align=\"left\">\n", + " <td>\n", + " <a target=\"_blank\" href=\"https://colab.sandbox.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/speech_emotion_tensorflow.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n", + " </td>\n", + " <td>\n", + " <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/speech_emotion_tensorflow.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n", + " </td>\n", + "</table>" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "h3nLfqdsTZdZ" + }, + "source": [ + "Speech emotion classification is a machine learning technique that identifies emotions from audio data. It involves data augmentation, feature extraction, preprocessing, and training an appropriate model. Apache Beam is well suited to structuring this workflow as a pipeline. 
This notebook demonstrates how to use Apache Beam for speech emotion classification. It does the following:\n", + "\n", + "* Imports and processes the CREMA-D dataset for speech emotion analysis.\n", + "* Performs various data augmentation and feature extraction techniques using the [Librosa](https://librosa.org/doc/latest/index.html) library.\n", + "* Develops a TensorFlow model to classify emotions.\n", + "* Stores the trained model.\n", + "* Constructs a Beam pipeline that:\n", + " * Creates a PCollection of audio samples.\n", + " * Applies preprocessing transforms.\n", + " * Uses the trained model to predict emotions.\n", + " * Stores the emotion predictions.\n", + "\n", + "For more information about using Apache Beam for machine learning pipelines, see [AI/ML Pipelines using Beam](https://beam.apache.org/documentation/ml/overview/)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s9xU1ws-DZwp" + }, + "source": [ + "## Installing Apache Beam" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "RbByEyZPMgbw", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "60db9d24-a8f5-4c40-dd6e-66878b7ce76b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.6/14.6 MB\u001b[0m \u001b[31m89.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m89.7/89.7 kB\u001b[0m \u001b[31m10.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m139.9/139.9 kB\u001b[0m \u001b[31m17.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m152.0/152.0 kB\u001b[0m \u001b[31m19.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.7/2.7 MB\u001b[0m \u001b[31m92.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m43.4/43.4 kB\u001b[0m \u001b[31m4.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m671.3/671.3 kB\u001b[0m \u001b[31m57.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.7/2.7 MB\u001b[0m \u001b[31m105.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m300.4/300.4 kB\u001b[0m \u001b[31m32.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25h Building wheel for crcmod (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for dill (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for hdfs (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Building wheel for docopt (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n" + ] + } + ], + "source": [ + "!pip install apache_beam --quiet" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dNg2hWBqrGwz" + }, + "source": [ + "\n", + "## Importing necessary libraries\n", + "\n", + "Here is a brief overview of the imported libraries:\n", + "* **[os](https://docs.python.org/3/library/os.html)**: Used for file and directory operations.\n", + "* **[NumPy](https://numpy.org/doc/stable/)**: Allows efficient numerical manipulation of arrays.\n", + "* **[Pandas](https://pandas.pydata.org/docs/)**: Facilitates data manipulation and analysis.\n", + "* **[Librosa](https://librosa.org/doc/latest/index.html)**: Provides tools for analyzing and working with audio data.\n", + "* **[IPython](https://ipython.readthedocs.io/en/stable/index.html)**: Displays multimedia content in the notebook. Here, it is used to play audio files.\n", + "* **[scikit-learn](https://scikit-learn.org/stable/index.html)**: Offers comprehensive machine learning tools. Here, it is used to preprocess and split the data.\n", + "* **[TensorFlow](https://www.tensorflow.org/api_docs)** and **[Keras](https://keras.io/api/)**: Enable building and training complex machine learning and deep learning models.\n", + "* **[TFModelHandlerNumpy](https://beam.apache.org/documentation/sdks/python-machine-learning/#tensorflow)**: Defines the configuration used to load and use the trained model. TFModelHandlerNumpy is used because the model is trained with TensorFlow and takes NumPy arrays as input.\n", + "* **[RunInference](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.html#apache_beam.ml.inference.RunInference)**: Loads the model and obtains predictions as part of the Apache Beam pipeline. For more information, see the Beam documentation on prediction and inference.\n", + "* **[Apache Beam](https://beam.apache.org/documentation/)**: Builds the pipeline that preprocesses the audio samples and predicts their emotions." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CzQkOP4v-X0Z" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "import librosa\n", + "from IPython.display import Audio\n", + "\n", + "from sklearn.model_selection import train_test_split\n", + "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", + "\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "\n", + "from keras import layers\n", + "from keras import models\n", + "from keras.utils import np_utils, to_categorical\n", + "from keras.models import Sequential\n", + "from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau\n", + "\n", + "from apache_beam.ml.inference.tensorflow_inference import TFModelHandlerNumpy\n", + "from apache_beam.ml.inference.base import RunInference\n", + "import apache_beam as beam" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dxAtTL2VIztQ" + }, + "source": [ + "## Importing dataset from Google Drive\n", + "\n", + "[CREMA-D](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4313618/) is a dataset of 7,442 audio recordings of actors portraying different emotions. The dataset can be downloaded from [Kaggle](https://www.kaggle.com/datasets/ejlok1/cremad). Because the dataset is large, uploading it to Colab every time you run the notebook is inconvenient. Instead, we downloaded it from Kaggle and uploaded it to Google Drive. Then we can access it [...] + "\n", + "If you follow this method, make sure that your Colab notebook is created with the same Google account that stores the folder."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PmaLie0lOI0g", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "e86b0467-8ab3-4ec8-e5cd-40e015ee272e" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mounted at /content/gdrive\n" + ] + } + ], + "source": [ + "from google.colab import drive\n", + "drive.mount('/content/gdrive', force_remount=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "W3ESVLNUvyqG" + }, + "source": [ + "Here, we define the path to the Google Drive folder that contains the audio files." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0BWrk6bn91Uu" + }, + "outputs": [], + "source": [ + "root_dir = \"/content/gdrive/My Drive/\"\n", + "Crema = root_dir + 'CREMA/'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VoROV5Vyvt-g" + }, + "source": [ + "Using the os library, we can list the audio files in the Google Drive folder." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7EXG23Yl-TfG", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d5b539c0-75a6-4ae3-be56-8a074539cd98" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "['1079_TIE_NEU_XX.wav',\n", + " '1079_TIE_SAD_XX.wav',\n", + " '1079_TSI_ANG_XX.wav',\n", + " '1079_TSI_DIS_XX.wav',\n", + " '1079_TSI_HAP_XX.wav',\n", + " '1079_TSI_FEA_XX.wav',\n", + " '1079_TSI_NEU_XX.wav',\n", + " '1079_TSI_SAD_XX.wav',\n", + " '1079_WSI_ANG_XX.wav',\n", + " '1079_WSI_DIS_XX.wav']" + ] + }, + "metadata": {}, + "execution_count": 6 + } + ], + "source": [ + "os.chdir(Crema)\n", + "os.listdir()[:10] # Listing the first 10 audio files" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3HFYUtJ5JYS_" + }, + "source": [ + "## Creating a DataFrame\n", + "We will create a DataFrame with two columns, path and 
emotion:\n", + "* Path: The path to a specific audio file in the directory.\n", + "* Emotion: The label that states the emotion portrayed in the audio file.\n", + "\n", + "The emotion is encoded in the third underscore-separated field of each file name (for example, `NEU` in `1079_TIE_NEU_XX.wav` denotes neutral), so it can be extracted from the name." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MVy9nx56-edb", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "outputId": "bc5e89dd-78e6-458b-9728-5030902f837e" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " Emotion Path\n", + "0 neutral /content/gdrive/My Drive/CREMA//1079_TIE_NEU_X...\n", + "1 sad /content/gdrive/My Drive/CREMA//1079_TIE_SAD_X...\n", + "2 angry /content/gdrive/My Drive/CREMA//1079_TSI_ANG_X...\n", + "3 disgust /content/gdrive/My Drive/CREMA//1079_TSI_DIS_X...\n", + "4 happy /content/gdrive/My Drive/CREMA//1079_TSI_HAP_X..." + ] + }, + "metadata": {}, + "execution_count": 7 + } + ], + "source": [ + "emotion_df = []\n", + "\n", + "for wav in os.listdir(Crema):\n", + " # The third field of the file name (e.g. NEU in 1079_TIE_NEU_XX.wav) encodes the emotion.\n", + " info = wav.partition(\".wav\")[0].split(\"_\")\n", + " if len(info) < 3:\n", + " continue\n", + " if info[2] == 'SAD':\n", + " emotion_df.append((\"sad\", Crema + \"/\" + wav))\n", + " elif info[2] == 'ANG':\n", + " emotion_df.append((\"angry\", Crema + \"/\" + wav))\n", + " elif info[2] == 'DIS':\n", + " emotion_df.append((\"disgust\", Crema + \"/\" + wav))\n", + " elif info[2] == 'FEA':\n", + " emotion_df.append((\"fear\", Crema + \"/\" + wav))\n", + " elif info[2] == 'HAP':\n", + " emotion_df.append((\"happy\", Crema + \"/\" + wav))\n", + " elif info[2] == 'NEU':\n", + " emotion_df.append((\"neutral\", Crema + \"/\" + wav))\n", + "\n", + "\n", + "Crema_df = pd.DataFrame(emotion_df, columns=[\"Emotion\", \"Path\"])\n", 
+ "\n", + "Crema_df.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XhKkVexzA46N" + }, + "source": [ + "## Preprocessing\n", + "\n", + "The audio files are in .wav format, but an ML model works on numerical data. We therefore preprocess the audio to extract numerical features and transform those features into a form that is more suitable for training, which improves the model's performance." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-sGwDJTKSoax" + }, + "source": [ + "### Data Augmentation\n", + "\n", + "Data augmentation transforms existing data in various ways to generate additional samples and make the model more robust. Each audio clip yields several slightly different versions, which exposes the model to a wider variety of data and reduces overfitting. We apply the following data augmentation techniques:\n", + "* **Noise injection**: Adds random Gaussian noise to the signal, scaled by a random fraction of the clip's maximum sample value.\n", + "* **Stretching**: Alters the speed of the audio without changing its pitch, simulating variations in speech rate or tempo.\n", + "* **Pitch Shifting**: Changes the pitch of the audio without changing its duration, simulating variations in speaker characteristics."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KYBjIdsfRyf_" + }, + "outputs": [], + "source": [ + "def noise(data):\n", + " # Add Gaussian noise scaled by a random fraction of the maximum sample value.\n", + " noise_amp = 0.035 * np.random.uniform() * np.amax(data)\n", + " data = data + noise_amp * np.random.normal(size=data.shape[0])\n", + " return data\n", + "\n", + "def stretch(data, rate=0.8):\n", + " # rate < 1 slows the audio down; the pitch is preserved.\n", + " return librosa.effects.time_stretch(data, rate=rate)\n", + "\n", + "def pitch(data, sampling_rate, pitch_factor=0.7):\n", + " # Shift the pitch by pitch_factor semitones; the duration is preserved.\n", + " return librosa.effects.pitch_shift(data, sr=sampling_rate, n_steps=pitch_factor)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IqOVow1ROH4p" + }, + "source": [ + "### Feature Extraction\n", + "\n", + "We need to extract numerical features from the audio to feed the ML model. The [Librosa](https://librosa.org/doc/latest/index.html) library makes this straightforward.\n", + "\n", + "First, it helps to understand the **mel scale**: a scale of pitches based on the way humans perceive and discriminate between different frequencies of sound. A common formula maps a frequency f in hertz to m = 2595 log10(1 + f / 700) mels; Librosa's default (Slaney) variant is linear below 1 kHz and logarithmic above it. 
Now, let us look at the features we extract from the audio:\n", + "\n", + "* **Zero Crossing Rate (ZCR)**: Measures how often the signal changes its sign (from positive to negative or vice versa) over time.\n", + "* **Chroma Short-Time Fourier Transform (STFT)**: Computes a chromagram from the short-time Fourier transform. The STFT breaks the audio signal into short, overlapping frames and computes the Fourier transform of each frame, giving a time-frequency representation of the signal; the chromagram then projects each frame's spectral energy onto the 12 pitch classes of the musical octave.\n", + "* **Mel-Frequency Cepstral Coefficients (MFCC)**: A compact set of coefficients obtained by taking the discrete cosine transform of the log-power mel spectrogram; they summarize the overall shape of the spectrum.\n", + "* **Mel Spectrogram**: A time-frequency representation of the power of an audio signal, with frequencies mapped onto the mel scale.\n", + "* **Root Mean Square (RMS)**: Computes the root mean square value of each frame, a measure of the amplitude or energy of the signal.\n", + "\n", + "You can read more about the features available in Librosa in the [feature extraction documentation](https://librosa.org/doc/latest/feature.html)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vHuzjpIfR2an" + }, + "outputs": [], + "source": [ + "def extract_features(data, sample_rate):\n", + " # Each feature matrix is averaged over its frames (time axis), so every clip\n", + " # yields one fixed-length feature vector.\n", + " # ZCR\n", + " result = np.array([])\n", + " zcr = np.mean(librosa.feature.zero_crossing_rate(y=data).T, axis=0)\n", + " result = np.hstack((result, zcr)) # stacking horizontally\n", + "\n", + " # Chroma STFT\n", + " stft = np.abs(librosa.stft(data))\n", + " chroma_stft = np.mean(librosa.feature.chroma_stft(S=stft, sr=sample_rate).T, axis=0)\n", + " result = np.hstack((result, chroma_stft)) # stacking horizontally\n", + "\n", + " # MFCC\n", + " mfcc = np.mean(librosa.feature.mfcc(y=data, sr=sample_rate).T, axis=0)\n", + " result = np.hstack((result, mfcc)) # stacking horizontally\n", + "\n", + " # Root Mean Square\n", + " rms = np.mean(librosa.feature.rms(y=data).T, axis=0)\n", + " result = np.hstack((result, rms)) # stacking horizontally\n", + "\n", + " # Mel spectrogram\n", + " mel = np.mean(librosa.feature.melspectrogram(y=data, 
sr=sample_rate).T, axis=0)\n", + " result = np.hstack((result, mel)) # stacking horizontally\n", + "\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4eo-DgBiAFH4" + }, + "source": [ + "The function below loads the audio stored at a path and extracts its features. It then applies the data augmentation techniques defined previously and extracts features from each augmented version as well. This gives three versions of each data item:\n", + "* Features from the original audio\n", + "* Features from the audio with added noise\n", + "* Features from the time-stretched and pitch-shifted audio\n", + "\n", + "Each version is added to the final dataset as an individual sample." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rksrNPA4R8v9" + }, + "outputs": [], + "source": [ + "def get_features(path):\n", + " # Load 2.5 s of audio starting 0.6 s into the file, resampled to librosa's default rate of 22050 Hz.\n", + " data, sample_rate = librosa.load(path, duration=2.5, offset=0.6)\n", + "\n", + " # without augmentation\n", + " normal_features = extract_features(data, sample_rate)\n", + " result = np.array(normal_features)\n", + "\n", + " # data with noise\n", + " noise_data = noise(data)\n", + " noise_features = extract_features(noise_data, sample_rate)\n", + " result = np.vstack((result, noise_features)) # stacking vertically\n", + "\n", + " # data with stretching and pitching\n", + " stretch_data = stretch(data)\n", + " stretch_pitch_data = pitch(stretch_data, sample_rate)\n", + " stretch_pitch_features = extract_features(stretch_pitch_data, sample_rate)\n", + " result = np.vstack((result, stretch_pitch_features)) # stacking vertically\n", + "\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "iIt_EokpBdHc" + }, + "source": [ + "Now we iterate through the Crema_df DataFrame, which contains the path and emotion of each audio sample. For each audio file, we extract the features of its three versions, append each feature vector to X, and append the corresponding emotion to Y once per version." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "81rZsmb4SBCH", + "outputId": "a8b14400-7aee-434d-c9a9-236db687d112" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "/usr/local/lib/python3.10/dist-packages/librosa/core/pitch.py:101: UserWarning: Trying to estimate tuning from empty frequency set.\n", + " return pitch_tuning(\n" + ] + } + ], + "source": [ + "X, Y = [], []\n", + "for path, emotion in zip(Crema_df.Path, Crema_df.Emotion):\n", + " # get_features returns three rows: original, noisy, and stretched + pitch-shifted\n", + " feature = get_features(path)\n", + " for ele in feature:\n", + " X.append(ele)\n", + " Y.append(emotion)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xa8RsEiYB-ru" + }, + "source": [ + "Here, we create a DataFrame from the lists X and Y." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 342 + }, + "id": "go4lU_2VSCxy", + "outputId": "05734eff-da43-4396-8c2e-c4c80ec2b0be" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " 0 1 2 3 4 5 6 \\\n", + "0 0.051835 0.552957 0.564289 0.512976 0.518041 0.528111 0.501150 \n", + "1 0.081790 0.611068 0.619012 0.578897 0.580346 0.604983 0.552418 \n", + "2 0.054339 0.525215 0.525026 0.478083 0.526773 0.554233 0.521426 \n", + "3 0.050157 0.514931 0.591693 0.464526 0.429137 0.480203 0.572344 \n", + "4 0.098122 0.606869 0.680955 0.572593 0.548943 0.581684 0.626757 \n", + "\n", + " 7 8 9 ... 153 154 \\\n", + "0 0.550490 0.673705 0.744412 ... 2.713831e-09 2.560777e-09 \n", + "1 0.557888 0.677792 0.749837 ... 8.333886e-05 7.936021e-05 \n", + "2 0.558976 0.671527 0.739728 ... 3.503047e-09 3.054322e-09 \n", + "3 0.722630 0.699706 0.676802 ... 3.512564e-09 3.153377e-09 \n", + "4 0.754920 0.735712 0.713573 ... 
1.368801e-04 1.329551e-04 \n", + "\n", + " 155 156 157 158 159 \\\n", + "0 2.451516e-09 2.369350e-09 2.308000e-09 2.264365e-09 2.232698e-09 \n", + "1 7.905496e-05 8.138233e-05 7.764955e-05 7.412745e-05 7.555283e-05 \n", + "2 2.943538e-09 2.634693e-09 2.343703e-09 2.368675e-09 2.363831e-09 \n", + "3 2.901090e-09 2.715085e-09 2.576861e-09 2.476340e-09 2.403195e-09 \n", + "4 1.397343e-04 1.433890e-04 1.408767e-04 1.354171e-04 1.373235e-04 \n", + "\n", + " 160 161 labels \n", + "0 2.212761e-09 2.200083e-09 neutral \n", + "1 8.043366e-05 8.144332e-05 neutral \n", + "2 1.876258e-09 6.538760e-10 neutral \n", + "3 2.354688e-09 2.325111e-09 sad \n", + "4 1.433754e-04 1.442893e-04 sad \n", + "\n", + "[5 rows x 163 columns]" + ], + "text/html": [ + "\n", + " <div id=\"df-d6ec0952-463b-4721-8872-88b40c464ae6\" class=\"colab-df-container\">\n", + " <div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>0</th>\n", + " <th>1</th>\n", + " <th>2</th>\n", + " <th>3</th>\n", + " <th>4</th>\n", + " <th>5</th>\n", + " <th>6</th>\n", + " <th>7</th>\n", + " <th>8</th>\n", + " <th>9</th>\n", + " <th>...</th>\n", + " <th>153</th>\n", + " <th>154</th>\n", + " <th>155</th>\n", + " <th>156</th>\n", + " <th>157</th>\n", + " <th>158</th>\n", + " <th>159</th>\n", + " <th>160</th>\n", + " <th>161</th>\n", + " <th>labels</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>0.051835</td>\n", + " <td>0.552957</td>\n", + " <td>0.564289</td>\n", + " <td>0.512976</td>\n", + " <td>0.518041</td>\n", + " <td>0.528111</td>\n", + " <td>0.501150</td>\n", + " <td>0.550490</td>\n", + " <td>0.673705</td>\n", 
+ " <td>0.744412</td>\n", + " <td>...</td>\n", + " <td>2.713831e-09</td>\n", + " <td>2.560777e-09</td>\n", + " <td>2.451516e-09</td>\n", + " <td>2.369350e-09</td>\n", + " <td>2.308000e-09</td>\n", + " <td>2.264365e-09</td>\n", + " <td>2.232698e-09</td>\n", + " <td>2.212761e-09</td>\n", + " <td>2.200083e-09</td>\n", + " <td>neutral</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>0.081790</td>\n", + " <td>0.611068</td>\n", + " <td>0.619012</td>\n", + " <td>0.578897</td>\n", + " <td>0.580346</td>\n", + " <td>0.604983</td>\n", + " <td>0.552418</td>\n", + " <td>0.557888</td>\n", + " <td>0.677792</td>\n", + " <td>0.749837</td>\n", + " <td>...</td>\n", + " <td>8.333886e-05</td>\n", + " <td>7.936021e-05</td>\n", + " <td>7.905496e-05</td>\n", + " <td>8.138233e-05</td>\n", + " <td>7.764955e-05</td>\n", + " <td>7.412745e-05</td>\n", + " <td>7.555283e-05</td>\n", + " <td>8.043366e-05</td>\n", + " <td>8.144332e-05</td>\n", + " <td>neutral</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>0.054339</td>\n", + " <td>0.525215</td>\n", + " <td>0.525026</td>\n", + " <td>0.478083</td>\n", + " <td>0.526773</td>\n", + " <td>0.554233</td>\n", + " <td>0.521426</td>\n", + " <td>0.558976</td>\n", + " <td>0.671527</td>\n", + " <td>0.739728</td>\n", + " <td>...</td>\n", + " <td>3.503047e-09</td>\n", + " <td>3.054322e-09</td>\n", + " <td>2.943538e-09</td>\n", + " <td>2.634693e-09</td>\n", + " <td>2.343703e-09</td>\n", + " <td>2.368675e-09</td>\n", + " <td>2.363831e-09</td>\n", + " <td>1.876258e-09</td>\n", + " <td>6.538760e-10</td>\n", + " <td>neutral</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>0.050157</td>\n", + " <td>0.514931</td>\n", + " <td>0.591693</td>\n", + " <td>0.464526</td>\n", + " <td>0.429137</td>\n", + " <td>0.480203</td>\n", + " <td>0.572344</td>\n", + " <td>0.722630</td>\n", + " <td>0.699706</td>\n", + " <td>0.676802</td>\n", + " <td>...</td>\n", + " <td>3.512564e-09</td>\n", + " <td>3.153377e-09</td>\n", + " 
<td>2.901090e-09</td>\n", + " <td>2.715085e-09</td>\n", + " <td>2.576861e-09</td>\n", + " <td>2.476340e-09</td>\n", + " <td>2.403195e-09</td>\n", + " <td>2.354688e-09</td>\n", + " <td>2.325111e-09</td>\n", + " <td>sad</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>0.098122</td>\n", + " <td>0.606869</td>\n", + " <td>0.680955</td>\n", + " <td>0.572593</td>\n", + " <td>0.548943</td>\n", + " <td>0.581684</td>\n", + " <td>0.626757</td>\n", + " <td>0.754920</td>\n", + " <td>0.735712</td>\n", + " <td>0.713573</td>\n", + " <td>...</td>\n", + " <td>1.368801e-04</td>\n", + " <td>1.329551e-04</td>\n", + " <td>1.397343e-04</td>\n", + " <td>1.433890e-04</td>\n", + " <td>1.408767e-04</td>\n", + " <td>1.354171e-04</td>\n", + " <td>1.373235e-04</td>\n", + " <td>1.433754e-04</td>\n", + " <td>1.442893e-04</td>\n", + " <td>sad</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 163 columns</p>\n", + "</div>\n", + " <div class=\"colab-df-buttons\">\n", + "\n", + " <div class=\"colab-df-container\">\n", + " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-d6ec0952-463b-4721-8872-88b40c464ae6')\"\n", + " title=\"Convert this dataframe to an interactive table.\"\n", + " style=\"display:none;\">\n", + "\n", + " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n", + " <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n", + " </svg>\n", + " </button>\n", + "\n", + " <style>\n", + " .colab-df-container {\n", + " display:flex;\n", + " gap: 12px;\n", + " }\n", + "\n", + " .colab-df-convert {\n", + " background-color: #E8F0FE;\n", + " border: none;\n", + " border-radius: 50%;\n", + " cursor: pointer;\n", + " display: none;\n", + " fill: #1967D2;\n", + " height: 32px;\n", + " padding: 0 0 0 0;\n", + " width: 32px;\n", + " }\n", + "\n", + " 
.colab-df-convert:hover {\n", + " background-color: #E2EBFA;\n", + " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", + " fill: #174EA6;\n", + " }\n", + "\n", + " .colab-df-buttons div {\n", + " margin-bottom: 4px;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert {\n", + " background-color: #3B4455;\n", + " fill: #D2E3FC;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-convert:hover {\n", + " background-color: #434B5C;\n", + " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", + " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", + " fill: #FFFFFF;\n", + " }\n", + " </style>\n", + "\n", + " <script>\n", + " const buttonEl =\n", + " document.querySelector('#df-d6ec0952-463b-4721-8872-88b40c464ae6 button.colab-df-convert');\n", + " buttonEl.style.display =\n", + " google.colab.kernel.accessAllowed ? 'block' : 'none';\n", + "\n", + " async function convertToInteractive(key) {\n", + " const element = document.querySelector('#df-d6ec0952-463b-4721-8872-88b40c464ae6');\n", + " const dataTable =\n", + " await google.colab.kernel.invokeFunction('convertToInteractive',\n", + " [key], {});\n", + " if (!dataTable) return;\n", + "\n", + " const docLinkHtml = 'Like what you see? 
Visit the ' +\n", + " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n", + " + ' to learn more about interactive tables.';\n", + " element.innerHTML = '';\n", + " dataTable['output_type'] = 'display_data';\n", + " await google.colab.output.renderOutput(dataTable, element);\n", + " const docLink = document.createElement('div');\n", + " docLink.innerHTML = docLinkHtml;\n", + " element.appendChild(docLink);\n", + " }\n", + " </script>\n", + " </div>\n", + "\n", + "\n", + "<div id=\"df-7c1015b9-2794-4a8e-9703-6ded718f55eb\">\n", + " <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-7c1015b9-2794-4a8e-9703-6ded718f55eb')\"\n", + " title=\"Suggest charts.\"\n", + " style=\"display:none;\">\n", + "\n", + "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n", + " width=\"24px\">\n", + " <g>\n", + " <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n", + " </g>\n", + "</svg>\n", + " </button>\n", + "\n", + "<style>\n", + " .colab-df-quickchart {\n", + " background-color: #E8F0FE;\n", + " border: none;\n", + " border-radius: 50%;\n", + " cursor: pointer;\n", + " display: none;\n", + " fill: #1967D2;\n", + " height: 32px;\n", + " padding: 0 0 0 0;\n", + " width: 32px;\n", + " }\n", + "\n", + " .colab-df-quickchart:hover {\n", + " background-color: #E2EBFA;\n", + " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n", + " fill: #174EA6;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-quickchart {\n", + " background-color: #3B4455;\n", + " fill: #D2E3FC;\n", + " }\n", + "\n", + " [theme=dark] .colab-df-quickchart:hover {\n", + " background-color: #434B5C;\n", + " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n", + " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n", + " fill: #FFFFFF;\n", + " }\n", + "</style>\n", + "\n", + " <script>\n", + 
" async function quickchart(key) {\n", + " const charts = await google.colab.kernel.invokeFunction(\n", + " 'suggestCharts', [key], {});\n", + " }\n", + " (() => {\n", + " let quickchartButtonEl =\n", + " document.querySelector('#df-7c1015b9-2794-4a8e-9703-6ded718f55eb button');\n", + " quickchartButtonEl.style.display =\n", + " google.colab.kernel.accessAllowed ? 'block' : 'none';\n", + " })();\n", + " </script>\n", + "</div>\n", + " </div>\n", + " </div>\n" + ] + }, + "metadata": {}, + "execution_count": 12 + } + ], + "source": [ + "Features = pd.DataFrame(X)\n", + "Features['labels'] = Y\n", + "Features.to_csv('features.csv', index=False)\n", + "Features.head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uk4y87eECD4F" + }, + "source": [ + "Here we separate the features from the labels: `X` holds the 162 extracted features of each audio sample, and `Y` holds the corresponding emotion labels." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EUObnydOSEXr" + }, + "outputs": [], + "source": [ + "X = Features.iloc[:, :-1].values\n", + "Y = Features['labels'].values" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "B92B4YVoCPed" + }, + "source": [ + "The [pad sequences](https://www.tensorflow.org/api_docs/python/tf/keras/utils/pad_sequences) function pads the input data to the same length, so that all samples have the same shape. Because `pad_sequences` returns `int32` values by default, we pass `dtype='float32'` so that the floating-point features are not truncated to integers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s_yiO3ZsrLjL" + }, + "outputs": [], + "source": [ + "X = tf.keras.utils.pad_sequences(X, dtype='float32')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AvHBQsARUeNb" + }, + "source": [ + "scikit-learn's [OneHotEncoder](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html) converts the categorical labels into numerical data. It creates one column per category, and each column contains only binary values: 1 if the sample belongs to that category and 0 otherwise.
For example, if we have the following categories:\n", + "\n", + "`[Anger, Disgust, Fear, Happy, Neutral, Sad]`\n", + "\n", + "and a specific audio clip belongs to the 'Anger' category, the OneHotEncoder transforms its label into:\n", + "\n", + "`[1, 0, 0, 0, 0, 0]`\n", + "\n", + "Note that the actual column order is determined by the encoder and may differ from the order shown above. You can inspect `encoder.categories_` to see which category each column represents." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "OT2HLn0HSF_3" + }, + "outputs": [], + "source": [ + "encoder = OneHotEncoder()\n", + "Y = encoder.fit_transform(np.array(Y).reshape(-1,1)).toarray()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "PZYKuz7FUfyj" + }, + "source": [ + "Next, we split the data into training and testing sets. By default, `train_test_split` holds out 25% of the samples for testing." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "snRl0fNdSJnv", + "outputId": "faf825f6-a7f6-4cd5-c3da-60bafaafddc2" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "((16744, 162), (16744, 6), (5582, 162), (5582, 6))" + ] + }, + "metadata": {}, + "execution_count": 16 + } + ], + "source": [ + "x_train, x_test, y_train, y_test = train_test_split(X, Y, random_state=0, shuffle=True)\n", + "x_train.shape, y_train.shape, x_test.shape, y_test.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IHoJ7egnUkrf" + }, + "source": [ + "Now we will scale the data.\n", + "* [Scaling](https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.scale.html) rescales each feature to zero mean and unit variance, so that all features have similar magnitudes. This typically helps gradient-based training converge.\n", + "* The scaler is fit on the training set only and then applied to the testing set, so that no information from the testing set leaks into training.\n", + "* The training set is used to train the model.\n", + "* The testing set is used to evaluate the model's accuracy."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "l0Jg7MN4SLls", + "outputId": "071cc13e-c60a-4c9b-b81d-ae02a1f67ae3" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "((16744, 162), (16744, 6), (5582, 162), (5582, 6))" + ] + }, + "metadata": {}, + "execution_count": 17 + } + ], + "source": [ + "scaler = StandardScaler()\n", + "x_train = scaler.fit_transform(x_train)\n", + "x_test = scaler.transform(x_test)\n", + "x_train.shape, y_train.shape, x_test.shape, y_test.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "HVEcjoXSrF_i", + "outputId": "ef9101b8-7af0-4733-8c85-e58038b5263d" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "((22326, 162), (16744, 162), (5582, 162))" + ] + }, + "metadata": {}, + "execution_count": 18 + } + ], + "source": [ + "X.shape, x_train.shape, x_test.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lbJnHZExUmOY" + }, + "source": [ + "Our model uses 1D convolutional layers, which expect a 3D input tensor with dimensions `(batch_size, time_steps, input_dim)`. So we expand the dimensions of our `x_train` and `x_test` arrays. The trailing 1 in the new shape is `input_dim`: each of the 162 time steps carries a single feature value."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "DfHhUoUFSMFL", + "outputId": "2e1a9134-e1f7-4826-a9ad-3f74ebe12849" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "((16744, 162, 1), (16744, 6), (5582, 162, 1), (5582, 6))" + ] + }, + "metadata": {}, + "execution_count": 19 + } + ], + "source": [ + "x_train = np.expand_dims(x_train, axis=2)\n", + "x_test = np.expand_dims(x_test, axis=2)\n", + "x_train.shape, y_train.shape, x_test.shape, y_test.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TAnQsHL9Uphd" + }, + "source": [ + "### Training the model\n", + "We will build a sequential model to classify speech emotions using TensorFlow and Keras. Here is an overview of the layers used:\n", + "* **Conv1D**: Applies a set of filters to capture patterns in sequential data like time series or audio, enabling feature extraction through sliding convolutions.\n", + "* **Activation**: Introduces non-linearity by applying an element-wise activation function to the input, enhancing the network's learning capacity.\n", + "* **BatchNormalization**: Normalizes input activations within a mini-batch, accelerating training by stabilizing and improving gradient flow.\n", + "* **Dropout**: Randomly deactivates a fraction of neurons during training, reducing overfitting by promoting generalization.\n", + "* **MaxPooling1D**: Downsamples the input by retaining the maximum value in each local region, reducing computation.\n", + "* **Flatten**: Reshapes input data from a multidimensional format into a 1D vector, suitable for fully connected layers.\n", + "* **Dense**: Connects each neuron to every neuron in the previous layer, allowing complex relationships to be learned during training.\n", + "\n", + "In the end, we need probabilities for each of the 6 classes of emotions, so we need 6 outputs. 
This is why the last Dense layer returns an array of size 6.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uFb71_ga5RoB", + "outputId": "cba3e5ba-f17e-4eb7-c5fd-0adc4ed060c9" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model: \"sequential\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " conv1d (Conv1D) (None, 162, 256) 1792 \n", + " \n", + " activation (Activation) (None, 162, 256) 0 \n", + " \n", + " conv1d_1 (Conv1D) (None, 162, 256) 393472 \n", + " \n", + " batch_normalization (BatchN (None, 162, 256) 1024 \n", + " ormalization) \n", + " \n", + " activation_1 (Activation) (None, 162, 256) 0 \n", + " \n", + " dropout (Dropout) (None, 162, 256) 0 \n", + " \n", + " max_pooling1d (MaxPooling1D (None, 20, 256) 0 \n", + " ) \n", + " \n", + " conv1d_2 (Conv1D) (None, 20, 128) 196736 \n", + " \n", + " activation_2 (Activation) (None, 20, 128) 0 \n", + " \n", + " conv1d_3 (Conv1D) (None, 20, 128) 98432 \n", + " \n", + " activation_3 (Activation) (None, 20, 128) 0 \n", + " \n", + " dropout_1 (Dropout) (None, 20, 128) 0 \n", + " \n", + " conv1d_4 (Conv1D) (None, 20, 128) 98432 \n", + " \n", + " activation_4 (Activation) (None, 20, 128) 0 \n", + " \n", + " conv1d_5 (Conv1D) (None, 20, 128) 98432 \n", + " \n", + " batch_normalization_1 (Batc (None, 20, 128) 512 \n", + " hNormalization) \n", + " \n", + " activation_5 (Activation) (None, 20, 128) 0 \n", + " \n", + " dropout_2 (Dropout) (None, 20, 128) 0 \n", + " \n", + " max_pooling1d_1 (MaxPooling (None, 2, 128) 0 \n", + " 1D) \n", + " \n", + " conv1d_6 (Conv1D) (None, 2, 64) 49216 \n", + " \n", + " activation_6 (Activation) (None, 2, 64) 0 \n", + " \n", + " conv1d_7 (Conv1D) (None, 2, 64) 24640 \n", + " \n", + " activation_7 
(Activation) (None, 2, 64) 0 \n", + " \n", + " dropout_3 (Dropout) (None, 2, 64) 0 \n", + " \n", + " flatten (Flatten) (None, 128) 0 \n", + " \n", + " dense (Dense) (None, 6) 774 \n", + " \n", + " activation_8 (Activation) (None, 6) 0 \n", + " \n", + "=================================================================\n", + "Total params: 963,462\n", + "Trainable params: 962,694\n", + "Non-trainable params: 768\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model = Sequential()\n", + "model.add(layers.Conv1D(256, 6, padding='same',input_shape=(x_train.shape[1],1)))\n", + "model.add(layers.Activation('relu'))\n", + "model.add(layers.Conv1D(256, 6, padding='same'))\n", + "model.add(layers.BatchNormalization())\n", + "model.add(layers.Activation('relu'))\n", + "model.add(layers.Dropout(0.2))\n", + "model.add(layers.MaxPooling1D(pool_size=(8)))\n", + "model.add(layers.Conv1D(128, 6, padding='same'))\n", + "model.add(layers.Activation('relu'))\n", + "model.add(layers.Conv1D(128, 6, padding='same'))\n", + "model.add(layers.Activation('relu'))\n", + "model.add(layers.Dropout(0.2))\n", + "model.add(layers.Conv1D(128, 6, padding='same'))\n", + "model.add(layers.Activation('relu'))\n", + "model.add(layers.Conv1D(128, 6, padding='same'))\n", + "model.add(layers.BatchNormalization())\n", + "model.add(layers.Activation('relu'))\n", + "model.add(layers.Dropout(0.2))\n", + "model.add(layers.MaxPooling1D(pool_size=(8)))\n", + "model.add(layers.Conv1D(64, 6, padding='same'))\n", + "model.add(layers.Activation('relu'))\n", + "model.add(layers.Conv1D(64, 6, padding='same'))\n", + "model.add(layers.Activation('relu'))\n", + "model.add(layers.Dropout(0.2))\n", + "model.add(layers.Flatten())\n", + "model.add(layers.Dense(6))\n", + "model.add(layers.Activation('softmax'))\n", + "opt = keras.optimizers.Adam(learning_rate=0.0001)\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g68nijvIF2Ca" + 
}, + "source": [ + "Now we will compile the model with the Adam optimizer defined above, categorical cross-entropy loss (which matches the one-hot encoded labels and the softmax output), and accuracy as the metric." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qWj3H05Q6pm_" + }, + "outputs": [], + "source": [ + "model.compile(loss='categorical_crossentropy', optimizer=opt, metrics=['accuracy'])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "p78y1ESDF5KQ" + }, + "source": [ + "Next, we will train our model. [ReduceLROnPlateau](https://keras.io/api/callbacks/reduce_lr_on_plateau/) multiplies the learning rate by 0.4 whenever the training loss has not improved for 2 epochs. [EarlyStopping](https://keras.io/api/callbacks/early_stopping/) monitors `val_loss` and stops training once it has not improved for 20 epochs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "H1F2haOv6pvZ", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f2f2bb9f-e445-4713-878c-82da952454fc" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Epoch 1/100\n", + "1047/1047 [==============================] - 29s 13ms/step - loss: 1.5803 - accuracy: 0.3272 - val_loss: 1.5216 - val_accuracy: 0.3739 - lr: 1.0000e-04\n", + "Epoch 2/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.4870 - accuracy: 0.3807 - val_loss: 1.5065 - val_accuracy: 0.3884 - lr: 1.0000e-04\n", + "Epoch 3/100\n", + "1047/1047 [==============================] - 12s 11ms/step - loss: 1.4541 - accuracy: 0.3952 - val_loss: 1.4635 - val_accuracy: 0.3954 - lr: 1.0000e-04\n", + "Epoch 4/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.4253 - accuracy: 0.4105 - val_loss: 1.4341 - val_accuracy: 0.4282 - lr: 1.0000e-04\n", + "Epoch 5/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.4092 - accuracy: 0.4199 - val_loss: 1.4595 - val_accuracy: 0.4077 - lr: 1.0000e-04\n", + "Epoch 6/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.3890 - accuracy:
0.4299 - val_loss: 1.4032 - val_accuracy: 0.4317 - lr: 1.0000e-04\n", + "Epoch 7/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.3709 - accuracy: 0.4471 - val_loss: 1.3958 - val_accuracy: 0.4294 - lr: 1.0000e-04\n", + "Epoch 8/100\n", + "1047/1047 [==============================] - 12s 11ms/step - loss: 1.3613 - accuracy: 0.4482 - val_loss: 1.4311 - val_accuracy: 0.4102 - lr: 1.0000e-04\n", + "Epoch 9/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.3420 - accuracy: 0.4563 - val_loss: 1.3901 - val_accuracy: 0.4409 - lr: 1.0000e-04\n", + "Epoch 10/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.3292 - accuracy: 0.4643 - val_loss: 1.3893 - val_accuracy: 0.4434 - lr: 1.0000e-04\n", + "Epoch 11/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.3162 - accuracy: 0.4689 - val_loss: 1.3742 - val_accuracy: 0.4482 - lr: 1.0000e-04\n", + "Epoch 12/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.3033 - accuracy: 0.4738 - val_loss: 1.3821 - val_accuracy: 0.4507 - lr: 1.0000e-04\n", + "Epoch 13/100\n", + "1047/1047 [==============================] - 12s 11ms/step - loss: 1.2889 - accuracy: 0.4833 - val_loss: 1.3452 - val_accuracy: 0.4609 - lr: 1.0000e-04\n", + "Epoch 14/100\n", + "1047/1047 [==============================] - 12s 11ms/step - loss: 1.2715 - accuracy: 0.4933 - val_loss: 1.3690 - val_accuracy: 0.4559 - lr: 1.0000e-04\n", + "Epoch 15/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.2642 - accuracy: 0.4916 - val_loss: 1.3460 - val_accuracy: 0.4618 - lr: 1.0000e-04\n", + "Epoch 16/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.2439 - accuracy: 0.5028 - val_loss: 1.3293 - val_accuracy: 0.4719 - lr: 1.0000e-04\n", + "Epoch 17/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.2287 - accuracy: 0.5073 - val_loss: 1.3309 - 
val_accuracy: 0.4663 - lr: 1.0000e-04\n", + "Epoch 18/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.2193 - accuracy: 0.5122 - val_loss: 1.3353 - val_accuracy: 0.4686 - lr: 1.0000e-04\n", + "Epoch 19/100\n", + "1047/1047 [==============================] - 13s 13ms/step - loss: 1.2044 - accuracy: 0.5237 - val_loss: 1.3370 - val_accuracy: 0.4636 - lr: 1.0000e-04\n", + "Epoch 20/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.1869 - accuracy: 0.5258 - val_loss: 1.3021 - val_accuracy: 0.4805 - lr: 1.0000e-04\n", + "Epoch 21/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.1744 - accuracy: 0.5288 - val_loss: 1.3028 - val_accuracy: 0.4807 - lr: 1.0000e-04\n", + "Epoch 22/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.1574 - accuracy: 0.5376 - val_loss: 1.3189 - val_accuracy: 0.4717 - lr: 1.0000e-04\n", + "Epoch 23/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.1437 - accuracy: 0.5492 - val_loss: 1.3197 - val_accuracy: 0.4694 - lr: 1.0000e-04\n", + "Epoch 24/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.1234 - accuracy: 0.5563 - val_loss: 1.3482 - val_accuracy: 0.4678 - lr: 1.0000e-04\n", + "Epoch 25/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.1152 - accuracy: 0.5568 - val_loss: 1.3050 - val_accuracy: 0.4821 - lr: 1.0000e-04\n", + "Epoch 26/100\n", + "1047/1047 [==============================] - 12s 11ms/step - loss: 1.1022 - accuracy: 0.5677 - val_loss: 1.2853 - val_accuracy: 0.4867 - lr: 1.0000e-04\n", + "Epoch 27/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.0844 - accuracy: 0.5678 - val_loss: 1.2719 - val_accuracy: 0.4925 - lr: 1.0000e-04\n", + "Epoch 28/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.0644 - accuracy: 0.5828 - val_loss: 1.2978 - val_accuracy: 0.4798 - lr: 
1.0000e-04\n", + "Epoch 29/100\n", + "1047/1047 [==============================] - 14s 14ms/step - loss: 1.0524 - accuracy: 0.5882 - val_loss: 1.2986 - val_accuracy: 0.4844 - lr: 1.0000e-04\n", + "Epoch 30/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 1.0364 - accuracy: 0.5920 - val_loss: 1.2919 - val_accuracy: 0.4894 - lr: 1.0000e-04\n", + "Epoch 31/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 1.0160 - accuracy: 0.6043 - val_loss: 1.2651 - val_accuracy: 0.4937 - lr: 1.0000e-04\n", + "Epoch 32/100\n", + "1047/1047 [==============================] - 12s 11ms/step - loss: 1.0056 - accuracy: 0.6058 - val_loss: 1.2905 - val_accuracy: 0.4841 - lr: 1.0000e-04\n", + "Epoch 33/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.9838 - accuracy: 0.6154 - val_loss: 1.2708 - val_accuracy: 0.4955 - lr: 1.0000e-04\n", + "Epoch 34/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 0.9778 - accuracy: 0.6135 - val_loss: 1.2651 - val_accuracy: 0.5032 - lr: 1.0000e-04\n", + "Epoch 35/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 0.9610 - accuracy: 0.6234 - val_loss: 1.3275 - val_accuracy: 0.4751 - lr: 1.0000e-04\n", + "Epoch 36/100\n", + "1047/1047 [==============================] - 13s 13ms/step - loss: 0.9461 - accuracy: 0.6349 - val_loss: 1.2683 - val_accuracy: 0.4971 - lr: 1.0000e-04\n", + "Epoch 37/100\n", + "1047/1047 [==============================] - 12s 11ms/step - loss: 0.9443 - accuracy: 0.6332 - val_loss: 1.2852 - val_accuracy: 0.4923 - lr: 1.0000e-04\n", + "Epoch 38/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 0.9169 - accuracy: 0.6452 - val_loss: 1.2813 - val_accuracy: 0.4961 - lr: 1.0000e-04\n", + "Epoch 39/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.9133 - accuracy: 0.6436 - val_loss: 1.2613 - val_accuracy: 0.5050 - lr: 1.0000e-04\n", + "Epoch 
40/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 0.8981 - accuracy: 0.6509 - val_loss: 1.2701 - val_accuracy: 0.5084 - lr: 1.0000e-04\n", + "Epoch 41/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.8874 - accuracy: 0.6515 - val_loss: 1.2848 - val_accuracy: 0.4928 - lr: 1.0000e-04\n", + "Epoch 42/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.8712 - accuracy: 0.6620 - val_loss: 1.2626 - val_accuracy: 0.5052 - lr: 1.0000e-04\n", + "Epoch 43/100\n", + "1047/1047 [==============================] - 15s 14ms/step - loss: 0.8702 - accuracy: 0.6597 - val_loss: 1.2687 - val_accuracy: 0.5109 - lr: 1.0000e-04\n", + "Epoch 44/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.8500 - accuracy: 0.6694 - val_loss: 1.2604 - val_accuracy: 0.5133 - lr: 1.0000e-04\n", + "Epoch 45/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 0.8305 - accuracy: 0.6759 - val_loss: 1.2698 - val_accuracy: 0.5122 - lr: 1.0000e-04\n", + "Epoch 46/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.8266 - accuracy: 0.6805 - val_loss: 1.2949 - val_accuracy: 0.5043 - lr: 1.0000e-04\n", + "Epoch 47/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.8132 - accuracy: 0.6860 - val_loss: 1.2778 - val_accuracy: 0.5021 - lr: 1.0000e-04\n", + "Epoch 48/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.7994 - accuracy: 0.6940 - val_loss: 1.2740 - val_accuracy: 0.5091 - lr: 1.0000e-04\n", + "Epoch 49/100\n", + "1047/1047 [==============================] - 13s 13ms/step - loss: 0.7836 - accuracy: 0.6936 - val_loss: 1.2925 - val_accuracy: 0.5070 - lr: 1.0000e-04\n", + "Epoch 50/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 0.7757 - accuracy: 0.7038 - val_loss: 1.3190 - val_accuracy: 0.5011 - lr: 1.0000e-04\n", + "Epoch 51/100\n", + "1047/1047 
[==============================] - 12s 11ms/step - loss: 0.7679 - accuracy: 0.7001 - val_loss: 1.2861 - val_accuracy: 0.5027 - lr: 1.0000e-04\n", + "Epoch 52/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 0.7542 - accuracy: 0.7114 - val_loss: 1.3435 - val_accuracy: 0.4927 - lr: 1.0000e-04\n", + "Epoch 53/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.7459 - accuracy: 0.7093 - val_loss: 1.3164 - val_accuracy: 0.5072 - lr: 1.0000e-04\n", + "Epoch 54/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.7287 - accuracy: 0.7193 - val_loss: 1.2878 - val_accuracy: 0.5188 - lr: 1.0000e-04\n", + "Epoch 55/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.7178 - accuracy: 0.7262 - val_loss: 1.3178 - val_accuracy: 0.5054 - lr: 1.0000e-04\n", + "Epoch 56/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.7076 - accuracy: 0.7258 - val_loss: 1.3746 - val_accuracy: 0.4912 - lr: 1.0000e-04\n", + "Epoch 57/100\n", + "1047/1047 [==============================] - 12s 11ms/step - loss: 0.6955 - accuracy: 0.7306 - val_loss: 1.3457 - val_accuracy: 0.5097 - lr: 1.0000e-04\n", + "Epoch 58/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.6843 - accuracy: 0.7364 - val_loss: 1.3558 - val_accuracy: 0.5000 - lr: 1.0000e-04\n", + "Epoch 59/100\n", + "1047/1047 [==============================] - 12s 12ms/step - loss: 0.6790 - accuracy: 0.7370 - val_loss: 1.3310 - val_accuracy: 0.5150 - lr: 1.0000e-04\n", + "Epoch 60/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.6683 - accuracy: 0.7431 - val_loss: 1.3515 - val_accuracy: 0.5127 - lr: 1.0000e-04\n", + "Epoch 61/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.6628 - accuracy: 0.7419 - val_loss: 1.3877 - val_accuracy: 0.4955 - lr: 1.0000e-04\n", + "Epoch 62/100\n", + "1047/1047 
[==============================] - 12s 11ms/step - loss: 0.6462 - accuracy: 0.7501 - val_loss: 1.3549 - val_accuracy: 0.5202 - lr: 1.0000e-04\n", + "Epoch 63/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.6305 - accuracy: 0.7597 - val_loss: 1.3709 - val_accuracy: 0.5109 - lr: 1.0000e-04\n", + "Epoch 64/100\n", + "1047/1047 [==============================] - 13s 12ms/step - loss: 0.6245 - accuracy: 0.7610 - val_loss: 1.3442 - val_accuracy: 0.5269 - lr: 1.0000e-04\n", + "Epoch 00064: early stopping\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "<keras.callbacks.History at 0x7a8558211bd0>" + ] + }, + "metadata": {}, + "execution_count": 22 + } + ], + "source": [ + "rlrp = ReduceLROnPlateau(monitor='loss', factor=0.4, verbose=0, patience=2, min_lr=0.0000001)\n", + "es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=20)\n", + "\n", + "model.fit(x_train, y_train, batch_size=16, epochs=100, validation_data=(x_test, y_test), callbacks=[es, rlrp])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UiuEkw7DF_M5" + }, + "source": [ + "The model reaches about 76% accuracy on the training set but only about 53% on the testing set, so its accuracy is limited and it overfits the training data. Speech emotion recognition is a difficult task, and building a good classifier typically requires much more training data and/or more sophisticated preprocessing. If you want to increase the accuracy, you can combine multiple datasets instead of using just one, and extract more features with the Librosa library. You can also experiment with LSTM layers in the model. Here are some of the popular speech emo [...]
+ "* [RAVDESS](https://www.kaggle.com/datasets/uwrfkaggler/ravdess-emotional-speech-audio)\n", + "* [LSSED](https://github.com/tobefans/LSSED)\n", + "* [TESS](https://www.kaggle.com/datasets/ejlok1/toronto-emotional-speech-set-tess)\n", + "* [IEMOCAP](https://www.kaggle.com/datasets/columbine/iemocap)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "btAlJmJjqgGS" + }, + "source": [ + "### Saving the model in a Google Cloud Storage bucket\n", + "In our final Beam pipeline, we will use RunInference. For that, the trained model must be stored in a location that the model handler can access. A Cloud Storage bucket is a convenient place to store it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NoOhXDcUMYaA", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "bb7b2af8-86b3-48a5-a404-edf935147bac" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:absl:Found untraced functions such as _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op, _jit_compiled_convolution_op while saving (showing 5 of 8). These functions will not be directly callable after loading.\n" + ] + } + ], + "source": [ + "save_model_dir = '' # Add the path to your GCS bucket here (a gs:// path)\n", + "model.save(save_model_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KRrMuKwmqo1I" + }, + "source": [ + "### Creating a model handler\n", + "A model handler tells RunInference how to load a trained model and how to run inference on batches of inputs. We use `TFModelHandlerNumpy` because our model was built with TensorFlow and takes NumPy arrays as input."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0dbI7-9KMbt6" + }, + "outputs": [], + "source": [ + "model_handler = TFModelHandlerNumpy(save_model_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZoStfR5HqVKV" + }, + "source": [ + "## Preprocessing functions for Beam pipeline\n", + "We need to define some functions that perform the same preprocessing we applied to our training data. We can't reuse all of the earlier preprocessing directly, because it operated on the whole multidimensional dataset at once, whereas each pipeline step processes a single element." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2MmGkjcpYJuf" + }, + "source": [ + "This function loads an audio file from its path using Librosa and extracts features with the previously defined `extract_features` function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MktIcCkzYJ8-" + }, + "outputs": [], + "source": [ + "def feature_extraction(element):\n", + " # element is the path of a single audio file\n", + " data, sample_rate = librosa.load(element, duration=2.5, offset=0.6)\n", + " return extract_features(data, sample_rate)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s8S4N1W-MeAG" + }, + "source": [ + "Here we scale the data using standardization, so that each sample has a mean of 0 and a standard deviation of 1. Note that this normalizes each sample with its own statistics, which differs from the `StandardScaler` used during training, where each feature was normalized with statistics computed across all training samples." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XlYhtZf-p01y" + }, + "outputs": [], + "source": [ + "def scaling(element):\n", + " element = (element-np.mean(element))/np.std(element)\n", + " return element" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1MOL8vfzMtpX" + }, + "source": [ + "Finally, we save our predictions in a list. For each input, RunInference returns a `PredictionResult` whose `inference` field contains the probability of each class. We set the highest probability to 1 and all other values to 0.
Now our new list is in a standard one hot encoded format, and we can use the inverse transform function of the OneHotEncoder to return which class the resultant array represents." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ehwQB-PdqwWh" + }, + "outputs": [], + "source": [ + "predictions = []" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Dos4eVTIjfpq" + }, + "outputs": [], + "source": [ + "from tensorflow.python.ops.numpy_ops import np_config\n", + "np_config.enable_numpy_behavior()\n", + "def save_predictions(element):\n", + " list_of_predictions = element.inference.tolist()\n", + " highest_prediction = max(list_of_predictions)\n", + " l = []\n", + " for i in range(len(list_of_predictions)):\n", + " if list_of_predictions[i] == highest_prediction:\n", + " l.append(1)\n", + " else:\n", + " l.append(0);\n", + " ans = encoder.inverse_transform(np.array(l).reshape(1,-1))[0][0]\n", + " predictions.append(ans)\n", + " print(ans)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pZjR4A01q5iu" + }, + "source": [ + "## Building the Beam Pipeline\n", + "This pipeline performs the following tasks\n", + "* Creates a PCollection of input paths\n", + "* Extracts features using the previously defined functions\n", + "* Performs scaling\n", + "* Runs inference on new data using the previously trained model\n", + "* Saves predictions in a list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZAwVSGg_mWB_" + }, + "outputs": [], + "source": [ + "pipeline_input = Crema_df[:2].Path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wJqMVam6lIHX", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "f6e5bb00-ff15-4aca-f8cd-98ab867a2f07" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "sad\n", + "sad\n" + ] + } + ], + "source": [ + "with beam.Pipeline() 
as p:\n", + " _ = (p | beam.Create(pipeline_input)\n", + " | beam.Map(feature_extraction)\n", + " | beam.Map(scaling)\n", + " | RunInference(model_handler)\n", + " | beam.Map(save_predictions)\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Pt5zEoXxS6wh", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 112 + }, + "outputId": "8d2bf9c8-2a46-4ff5-d200-af370aaa24a7" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " Emotion Path\n", + "0 neutral /content/gdrive/My Drive/CREMA//1079_TIE_NEU_X...\n", + "1 sad /content/gdrive/My Drive/CREMA//1079_TIE_SAD_X..." + ] + }, + "metadata": {}, + "execution_count": 31 + } + ], + "source": [ + "Crema_df[:2]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "WNjJSGh8rpbP", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "outputId": "74b9d9d4-daab-4b7d-c5db-c07078ea5d46" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "<IPython.lib.display.Audio object>" + ] + }, + "metadata": {}, + "execution_count": 32 + } + ], + "source": [ + "from IPython.display import Audio\n", + "Audio(Crema_df.iloc[0].Path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "LN69bC_ksG0k", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 75 + }, + "outputId": "8b7d1eea-95fb-47d8-9643-ccd34dd3b87c" + }, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "<IPython.lib.display.Audio object>" + ] + }, + "metadata": {}, + "execution_count": 33 + } + ], + "source": [ + "Audio(Crema_df.iloc[1].Path)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file
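The notebook's `scaling` and `save_predictions` cells standardize each feature vector and then turn the class probabilities from RunInference into a one-hot vector before calling `OneHotEncoder.inverse_transform`. The following is a minimal sketch of those two steps using NumPy only, so the logic can be checked outside Beam and TensorFlow. It is an illustration under stated assumptions rather than part of the notebook: `standardize`, `decode`, and `LABELS` are hypothetical names, and the label order is assumed instead of being read from the notebook's fitted encoder.

```python
import numpy as np

# Hypothetical label order (assumption). In the notebook, the order comes from
# the fitted OneHotEncoder (`encoder.categories_`), which sorts string categories.
LABELS = ["angry", "disgust", "fear", "happy", "neutral", "sad"]


def standardize(features):
    """Scale one feature vector to mean 0 and standard deviation 1, like `scaling`."""
    features = np.asarray(features, dtype=float)
    return (features - features.mean()) / features.std()


def decode(probabilities, labels=LABELS):
    """Map one row of class probabilities to a label through a one-hot vector."""
    probabilities = np.asarray(probabilities, dtype=float)
    one_hot = np.zeros_like(probabilities, dtype=int)
    # np.argmax returns the first index on ties, so exactly one entry becomes 1.
    one_hot[np.argmax(probabilities)] = 1
    # For a single categorical feature, inverting a one-hot vector amounts to
    # looking up the label at the position of the single 1.
    return labels[int(np.flatnonzero(one_hot)[0])]


if __name__ == "__main__":
    scaled = standardize([0.2, 0.5, 0.8, 1.1])
    print(scaled.round(3))
    print(decode([0.05, 0.10, 0.05, 0.10, 0.20, 0.50]))  # sad
```

Using `np.argmax` guarantees a single 1 even when two classes tie, so the vector is always a valid one-hot encoding.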