Maison  >  Article  >  développement back-end  >  Prédiction musicale Tensorflow

Prédiction musicale Tensorflow

WBOY
WBOYoriginal
2024-08-27 06:03:08615parcourir

Tensorflow music prediction

Dans cet article, je montre comment utiliser Tensorflow pour prédire un style de musique.
Dans mon exemple, je compare la techno et la musique classique.

Vous pouvez retrouver le code sur mon github :
https://github.com/victordalet/sound_to_partition


I - Ensemble de données

Pour la première étape, vous devez créer un dossier d'ensemble de données et à l'intérieur ajouter un dossier pour le style de musique, par exemple, j'ajoute un dossier techno et un dossier classique dans lesquels mettre mon son wav.

II-Trainer

Je crée un fichier train, avec les arguments max_epochs à compléter.

Modifiez les classes dans le constructeur qui correspondent à votre répertoire dans le dossier dataset.

Dans la méthode de chargement et de traitement, je récupère le fichier wav dans un répertoire différent et j'obtiens le spectrogramme.

À des fins de formation, j'utilise les convolutions et le modèle Keras.

import os
import sys
from typing import List

import librosa
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.image import resize



class Train:

    def __init__(self):
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.data_dir: str = 'dataset'
        self.classes: List[str] = ['techno','classic']
        self.max_epochs: int = int(sys.argv[1])

    @staticmethod
    def load_and_preprocess_data(data_dir, classes, target_shape=(128, 128)):
        data = []
        labels = []

        for i, class_name in enumerate(classes):
            class_dir = os.path.join(data_dir, class_name)
            for filename in os.listdir(class_dir):
                if filename.endswith('.wav'):
                    file_path = os.path.join(class_dir, filename)
                    audio_data, sample_rate = librosa.load(file_path, sr=None)
                    mel_spectrogram = librosa.feature.melspectrogram(y=audio_data, sr=sample_rate)
                    mel_spectrogram = resize(np.expand_dims(mel_spectrogram, axis=-1), target_shape)
                    data.append(mel_spectrogram)
                    labels.append(i)

        return np.array(data), np.array(labels)

    def create_model(self):
        data, labels = self.load_and_preprocess_data(self.data_dir, self.classes)
        labels = to_categorical(labels, num_classes=len(self.classes))  # Convert labels to one-hot encoding
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(data, labels, test_size=0.2,
                                                                                random_state=42)

        input_shape = self.X_train[0].shape
        input_layer = Input(shape=input_shape)
        x = Conv2D(32, (3, 3), activation='relu')(input_layer)
        x = MaxPooling2D((2, 2))(x)
        x = Conv2D(64, (3, 3), activation='relu')(x)
        x = MaxPooling2D((2, 2))(x)
        x = Flatten()(x)
        x = Dense(64, activation='relu')(x)
        output_layer = Dense(len(self.classes), activation='softmax')(x)
        self.model = Model(input_layer, output_layer)

        self.model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

    def train_model(self):
        self.model.fit(self.X_train, self.y_train, epochs=self.max_epochs, batch_size=32,
                       validation_data=(self.X_test, self.y_test))
        test_accuracy = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        print(test_accuracy[1])

    def save_model(self):
        self.model.save('weight.h5')


if __name__ == '__main__':
    train = Train()
    train.create_model()
    train.train_model()
    train.save_model()

III - Essai

Pour tester et utiliser le modèle, j'ai créé cette classe pour récupérer le poids et prédire le style de la musique.

N'oubliez pas d'ajouter les bonnes classes au constructeur.

from typing import List

import librosa
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.image import resize
import tensorflow as tf



class Test:

    def __init__(self, audio_file_path: str):
        self.model = load_model('weight.h5')
        self.target_shape = (128, 128)
        self.classes: List[str] = ['techno','classic']
        self.audio_file_path: str = audio_file_path

    def test_audio(self, file_path, model):
        audio_data, sample_rate = librosa.load(file_path, sr=None)
        mel_spectrogram = librosa.feature.melspectrogram(y=audio_data, sr=sample_rate)
        mel_spectrogram = resize(np.expand_dims(mel_spectrogram, axis=-1), self.target_shape)
        mel_spectrogram = tf.reshape(mel_spectrogram, (1,) + self.target_shape + (1,))

        predictions = model.predict(mel_spectrogram)

        class_probabilities = predictions[0]

        predicted_class_index = np.argmax(class_probabilities)

        return class_probabilities, predicted_class_index

    def test(self):
        class_probabilities, predicted_class_index = self.test_audio(self.audio_file_path, self.model)

        for i, class_label in enumerate(self.classes):
            probability = class_probabilities[i]
            print(f'Class: {class_label}, Probability: {probability:.4f}')

        predicted_class = self.classes[predicted_class_index]
        accuracy = class_probabilities[predicted_class_index]
        print(f'The audio is classified as: {predicted_class}')
        print(f'Accuracy: {accuracy:.4f}')

Ce qui précède est le contenu détaillé de. pour plus d'informations, suivez d'autres articles connexes sur le site Web de PHP en chinois!

Déclaration:
Le contenu de cet article est volontairement contribué par les internautes et les droits d'auteur appartiennent à l'auteur original. Ce site n'assume aucune responsabilité légale correspondante. Si vous trouvez un contenu suspecté de plagiat ou de contrefaçon, veuillez contacter admin@php.cn