Heim  >  Artikel  >  Backend-Entwicklung  >  Tensorflow-Musikvorhersage

Tensorflow-Musikvorhersage

WBOY
WBOYOriginal
2024-08-27 06:03:08615Durchsuche

Tensorflow music prediction

In diesem Artikel zeige ich, wie man Tensorflow verwendet, um einen Musikstil vorherzusagen.
In meinem Beispiel vergleiche ich Techno und klassische Musik.

Den Code finden Sie auf meinem Github:
https://github.com/victordalet/sound_to_partition


I – Datensatz

Im ersten Schritt müssen Sie einen Datensatzordner erstellen und darin einen Ordner für den Musikstil hinzufügen. Ich füge beispielsweise einen Techno-Ordner und einen Classic-Ordner hinzu, in die ich meinen WAV-Song einfüge.

II - Zug

Ich erstelle eine Zugdatei mit den Argumenten max_epochs, die vervollständigt werden sollen.

Ändern Sie die Klassen im Konstruktor, die Ihrem Verzeichnis im Datensatzordner entsprechen.

Bei der Lade- und Verarbeitungsmethode rufe ich die WAV-Datei aus einem anderen Verzeichnis ab und erhalte das Spektogramm.

Zu Trainingszwecken verwende ich die Keras-Faltungen und das Keras-Modell.

import os
import sys
from typing import List

import librosa
import numpy as np
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
from tensorflow.keras.utils import to_categorical
from tensorflow.image import resize



class Train:

    def __init__(self):
        self.X_train = None
        self.X_test = None
        self.y_train = None
        self.y_test = None
        self.data_dir: str = 'dataset'
        self.classes: List[str] = ['techno','classic']
        self.max_epochs: int = int(sys.argv[1])

    @staticmethod
    def load_and_preprocess_data(data_dir, classes, target_shape=(128, 128)):
        data = []
        labels = []

        for i, class_name in enumerate(classes):
            class_dir = os.path.join(data_dir, class_name)
            for filename in os.listdir(class_dir):
                if filename.endswith('.wav'):
                    file_path = os.path.join(class_dir, filename)
                    audio_data, sample_rate = librosa.load(file_path, sr=None)
                    mel_spectrogram = librosa.feature.melspectrogram(y=audio_data, sr=sample_rate)
                    mel_spectrogram = resize(np.expand_dims(mel_spectrogram, axis=-1), target_shape)
                    data.append(mel_spectrogram)
                    labels.append(i)

        return np.array(data), np.array(labels)

    def create_model(self):
        data, labels = self.load_and_preprocess_data(self.data_dir, self.classes)
        labels = to_categorical(labels, num_classes=len(self.classes))  # Convert labels to one-hot encoding
        self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(data, labels, test_size=0.2,
                                                                                random_state=42)

        input_shape = self.X_train[0].shape
        input_layer = Input(shape=input_shape)
        x = Conv2D(32, (3, 3), activation='relu')(input_layer)
        x = MaxPooling2D((2, 2))(x)
        x = Conv2D(64, (3, 3), activation='relu')(x)
        x = MaxPooling2D((2, 2))(x)
        x = Flatten()(x)
        x = Dense(64, activation='relu')(x)
        output_layer = Dense(len(self.classes), activation='softmax')(x)
        self.model = Model(input_layer, output_layer)

        self.model.compile(optimizer=Adam(learning_rate=0.001), loss='categorical_crossentropy', metrics=['accuracy'])

    def train_model(self):
        self.model.fit(self.X_train, self.y_train, epochs=self.max_epochs, batch_size=32,
                       validation_data=(self.X_test, self.y_test))
        test_accuracy = self.model.evaluate(self.X_test, self.y_test, verbose=0)
        print(test_accuracy[1])

    def save_model(self):
        self.model.save('weight.h5')


if __name__ == '__main__':
    train = Train()
    train.create_model()
    train.train_model()
    train.save_model()

III - Test

Um das Modell zu testen und zu verwenden, habe ich diese Klasse erstellt, um das Gewicht abzurufen und den Stil der Musik vorherzusagen.

Vergessen Sie nicht, dem Konstruktor die richtigen Klassen hinzuzufügen.

from typing import List

import librosa
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.image import resize
import tensorflow as tf



class Test:

    def __init__(self, audio_file_path: str):
        self.model = load_model('weight.h5')
        self.target_shape = (128, 128)
        self.classes: List[str] = ['techno','classic']
        self.audio_file_path: str = audio_file_path

    def test_audio(self, file_path, model):
        audio_data, sample_rate = librosa.load(file_path, sr=None)
        mel_spectrogram = librosa.feature.melspectrogram(y=audio_data, sr=sample_rate)
        mel_spectrogram = resize(np.expand_dims(mel_spectrogram, axis=-1), self.target_shape)
        mel_spectrogram = tf.reshape(mel_spectrogram, (1,) + self.target_shape + (1,))

        predictions = model.predict(mel_spectrogram)

        class_probabilities = predictions[0]

        predicted_class_index = np.argmax(class_probabilities)

        return class_probabilities, predicted_class_index

    def test(self):
        class_probabilities, predicted_class_index = self.test_audio(self.audio_file_path, self.model)

        for i, class_label in enumerate(self.classes):
            probability = class_probabilities[i]
            print(f'Class: {class_label}, Probability: {probability:.4f}')

        predicted_class = self.classes[predicted_class_index]
        accuracy = class_probabilities[predicted_class_index]
        print(f'The audio is classified as: {predicted_class}')
        print(f'Accuracy: {accuracy:.4f}')

Das obige ist der detaillierte Inhalt vonTensorflow-Musikvorhersage. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Der Inhalt dieses Artikels wird freiwillig von Internetnutzern beigesteuert und das Urheberrecht liegt beim ursprünglichen Autor. Diese Website übernimmt keine entsprechende rechtliche Verantwortung. Wenn Sie Inhalte finden, bei denen der Verdacht eines Plagiats oder einer Rechtsverletzung besteht, wenden Sie sich bitte an admin@php.cn