Heim >Technologie-Peripheriegeräte >KI >Warum sind baumbasierte Modelle immer noch besser als Deep Learning für Tabellendaten?
Deep Learning hat in Bereichen wie Bildern, Sprache und sogar Audio große Fortschritte gemacht. Allerdings schneidet Deep Learning bei der Verarbeitung tabellarischer Daten nur mittelmäßig ab. Da Tabellendaten Merkmale wie ungleichmäßige Merkmale, kleine Stichprobengröße und große Extremwerte aufweisen, ist es schwierig, entsprechende Invarianten zu finden.
Baumbasierte Modelle sind nicht differenzierbar und können nicht gemeinsam mit Deep-Learning-Modulen trainiert werden, daher ist die Erstellung tabellenspezifischer Deep-Learning-Architekturen ein sehr aktives Forschungsgebiet. Viele Studien haben behauptet, dass sie baumbasierte Modelle übertreffen oder mit ihnen konkurrieren können, doch ihre Studien stießen auf große Skepsis.
Die Tatsache, dass es beim Lernen anhand tabellarischer Daten an etablierten Benchmarks mangelt, gibt Forschern viel Freiheit bei der Bewertung ihrer Methoden. Darüber hinaus sind die meisten online verfügbaren tabellarischen Datensätze im Vergleich zu Benchmarks in anderen Subdomänen des maschinellen Lernens klein, was die Auswertung erschwert.
Um diese Bedenken auszuräumen, schlagen Forscher des französischen Nationalinstituts für Information und Automatisierung, der Universität Sorbonne und anderer Institutionen einen tabellarischen Datenbenchmark vor, der die neuesten Deep-Learning-Modelle bewerten und zeigen kann, dass baumbasierte Modelle effektiver sind mittelgroßes Still-SOTA für tabellarische Datensätze.
Für diese Schlussfolgerung liefert der Artikel schlüssige Beweise: Bei tabellarischen Daten ist es einfacher, mit baumbasierten Methoden gute Vorhersagen zu treffen (sogar mit modernen Architekturen).
Adresse des Papiers: https://hal.archives-ouvertes.fr/hal-03723551/documentEs ist erwähnenswert, dass einer der Autoren des Papiers Gaël Varoquaux ist, ein Scikit-learn-Autor. Einer der Leiter des Programms. Das Projekt hat sich mittlerweile zu einer der beliebtesten Bibliotheken für maschinelles Lernen auf GitHub entwickelt. Der Artikel „Scikit-learn: Machine learning in Python“ von Gaël Varoquaux hat 58.949 Zitate.
Der Beitrag dieses Artikels lässt sich wie folgt zusammenfassen:
Diese Studie erstellt einen neuen Benchmark für Tabellendaten (Auswahl von 45 offenen Datensätzen) und teilt diese Datensätze über OpenML, wodurch sie einfach zu verwenden sind .
Diese Studie vergleicht Deep-Learning-Modelle und baumbasierte Modelle unter verschiedenen Einstellungen für Tabellendaten und berücksichtigt die Kosten für die Auswahl von Hyperparametern. Die Studie teilt auch Rohergebnisse aus Zufallssuchen, die es Forschern ermöglichen, neue Algorithmen für ein festes Budget für die Hyperparameteroptimierung kostengünstig zu testen.
Der neue Benchmark bezieht sich auf 45 Tabellendatensätze, und der Auswahlbenchmark lautet wie folgt:
Unter den baumbasierten Modellen wählten die Forscher drei SOTA-Modelle: RandomForest von Scikit Learn, GradientBoostingTrees (GBTs) und XGBoost. Die Studie führte die folgenden Benchmarks für Deep-Modelle durch: MLP, Resnet, FT Transformer, SAINT. Abbildung 1 und Abbildung 2 zeigen die Benchmark-Ergebnisse für verschiedene Arten von Datensätzen
Induktiver Bias. Baumbasierte Modelle schlagen neuronale Netze bei einer Vielzahl von Hyperparameteroptionen. Tatsächlich haben die besten Methoden zur Verarbeitung tabellarischer Daten zwei gemeinsame Eigenschaften: Sie sind Ensemble-Methoden, Bagging (Random Forests) oder Boosting (XGBoost, GBT), und die in diesen Methoden verwendeten schwachen Lernenden sind Entscheidungsbäume.
Ergebnis 1: Neuronale Netze (NN) neigen dazu, Lösungen zu stark zu glätten
Wie in Abbildung 3 gezeigt, nimmt die Zielfunktion auf dem glatten Trainingssatz bei kleineren Maßstäben basierend auf der Genauigkeit erheblich ab des Baummodells, wird aber kaum Auswirkungen auf das NN haben. Diese Ergebnisse deuten darauf hin, dass die Zielfunktion im Datensatz nicht glatt ist und dass NN im Vergleich zu baumbasierten Modellen Schwierigkeiten hat, sich an diese unregelmäßigen Funktionen anzupassen. Dies steht im Einklang mit den Erkenntnissen von Rahaman et al., die herausfanden, dass NNs auf niederfrequente Funktionen ausgerichtet sind. Entscheidungsbaumbasierte Modelle lernen stückweise konstante Funktionen ohne solche Verzerrungen.
Ergebnis 2: Nicht informative Merkmale können sich stärker auf MLP-ähnliche NN auswirken. Der tabellarische Datensatz enthält viele nicht informative Merkmale. Für jeden Datensatz basiert die Studie auf den Merkmalen wird sich dafür entscheiden, einen bestimmten Anteil an Features zu verwerfen (normalerweise nach zufälliger Gesamtstruktur geordnet). Wie aus Abbildung 4 ersichtlich ist, hat das Entfernen von mehr als der Hälfte der Merkmale kaum Auswirkungen auf die Klassifizierungsgenauigkeit von GBT.
Abbildung 5 Es ist ersichtlich, dass das Entfernen nicht-informativer Funktionen (5a) die Leistungslücke zwischen MLP (Resnet) und anderen Modellen (FT Transformers und baumbasierte Modelle) verringert, während nicht-informative Funktionen hinzugefügt werden. informative Merkmale Merkmale vergrößern die Lücke, was darauf hindeutet, dass MLP gegenüber nicht informativen Merkmalen weniger robust ist. Wenn der Forscher in Abbildung 5a einen größeren Anteil an Merkmalen entfernt, werden auch nützliche Informationsmerkmale entsprechend entfernt. Abbildung 5b zeigt, dass die durch das Entfernen dieser Merkmale verursachte Verringerung der Genauigkeit durch das Entfernen nicht informativer Merkmale ausgeglichen werden kann, was für MLP im Vergleich zu anderen Modellen hilfreicher ist (gleichzeitig entfernt diese Studie auch redundante Merkmale und hat keinen Einfluss auf das Modell). Leistung).
Ergebnis 3: Durch Rotation sind die Daten nichtinvariant
Warum ist MLP im Vergleich zu anderen Modellen anfälliger für nichtinformative Merkmale? Eine Antwort ist, dass MLPs rotationsinvariant sind: Der Prozess des Erlernens eines MLP auf dem Trainingssatz und seiner Auswertung auf dem Testsatz ist invariant, wenn Rotationen auf Trainings- und Testsatzfunktionen angewendet werden. Tatsächlich weist jeder rotationsinvariante Lernprozess im schlimmsten Fall eine Stichprobenkomplexität auf, die zumindest in der Anzahl irrelevanter Merkmale linear zunimmt. Um nutzlose Features zu entfernen, muss der rotationsinvariante Algorithmus intuitiv zunächst die ursprüngliche Ausrichtung des Features finden und dann das am wenigsten informative Feature auswählen.
Abbildung 6a zeigt die Änderung der Testgenauigkeit, wenn der Datensatz zufällig rotiert wird, und bestätigt, dass nur Resnets rotationsinvariant sind. Insbesondere kehrt die zufällige Rotation die Reihenfolge der Leistung um: Das Ergebnis sind NNs über baumbasierten Modellen und Resnets über FT-Transformatoren, was darauf hindeutet, dass Rotationsinvarianz unerwünscht ist. Tatsächlich haben tabellarische Daten oft individuelle Bedeutungen, wie z. B. Alter, Gewicht usw. Wie in Abbildung 6b dargestellt: Das Entfernen der unwichtigsten Hälfte der Features in jedem Datensatz (vor der Rotation) verringert die Leistung aller Modelle außer Resnets, aber im Vergleich zur Verwendung aller Features ohne Entfernen von Features ist der Rückgang geringer.
Das obige ist der detaillierte Inhalt vonWarum sind baumbasierte Modelle immer noch besser als Deep Learning für Tabellendaten?. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!