Heim >Technologie-Peripheriegeräte >KI >Klassenungleichgewichtsproblem bei der Bildklassifizierung
Kategorienungleichgewichtsproblem bei der Bildklassifizierung, es sind spezifische Codebeispiele erforderlich
Zusammenfassung: Bei Bildklassifizierungsaufgaben können die Kategorien im Datensatz ein Ungleichgewichtsproblem aufweisen, das heißt, die Anzahl der Stichproben in einigen Kategorien ist viel größer als andere Kategorien. Dieses Klassenungleichgewicht kann sich negativ auf das Training und die Leistung des Modells auswirken. In diesem Artikel werden die Ursachen und Auswirkungen des Klassenungleichgewichtsproblems beschrieben und einige konkrete Codebeispiele zur Lösung des Problems bereitgestellt.
Das Klassenungleichgewichtsproblem hat einige negative Auswirkungen auf das Training und die Leistung des Modells. Erstens kann es aufgrund der geringen Stichprobenanzahl in einigen Kategorien dazu kommen, dass das Modell diese Kategorien falsch einschätzt. Bei einem Problem mit zwei Klassifizierungen beträgt die Anzahl der Stichproben in den beiden Kategorien beispielsweise 10 bzw. 1000. Wenn das Modell kein Lernen durchführt und alle Stichproben direkt als Kategorien mit einer größeren Anzahl von Stichproben vorhersagt, ist die Genauigkeit höher sehr hoch, aber in Wirklichkeit werden die Proben nicht effektiv klassifiziert. Zweitens kann das Modell aufgrund einer unausgewogenen Stichprobenverteilung dazu tendieren, Kategorien mit einer größeren Anzahl von Stichproben vorherzusagen, was zu einer schlechten Klassifizierungsleistung für andere Kategorien führt. Schließlich kann eine unausgewogene Kategorienverteilung dazu führen, dass die Trainingsstichproben des Modells für Minderheitenkategorien nicht ausreichen, was dazu führt, dass das erlernte Modell eine schlechte Generalisierungsfähigkeit für Minderheitskategorien aufweist.
Unterabtastung bezieht sich auf das zufällige Löschen einiger Stichproben aus Kategorien mit einer größeren Anzahl von Stichproben, sodass die Anzahl der Stichproben in jeder Kategorie näher beieinander liegt. Diese Methode ist einfach und unkompliziert, kann jedoch zu Informationsverlusten führen, da durch das Löschen von Beispielen möglicherweise einige wichtige Funktionen verloren gehen.
Oversampling bezieht sich auf das Kopieren einiger Samples aus Kategorien mit einer geringeren Anzahl an Samples, um die Anzahl der Samples in jeder Kategorie ausgewogener zu gestalten. Diese Methode kann die Anzahl der Stichproben erhöhen, kann jedoch zu Überanpassungsproblemen führen, da das Kopieren von Stichproben dazu führen kann, dass das Modell zu stark an den Trainingssatz angepasst wird und eine schlechte Generalisierungsfähigkeit aufweist.
Gewichtungsanpassung bezieht sich auf die unterschiedliche Gewichtung von Stichproben verschiedener Kategorien in der Verlustfunktion, sodass das Modell Kategorien mit einer geringeren Anzahl von Stichproben mehr Aufmerksamkeit schenkt. Diese Methode kann das Problem des Klassenungleichgewichts effektiv lösen, ohne zusätzliche Stichproben einzuführen. Der spezifische Ansatz besteht darin, das Gewicht jeder Kategorie in der Verlustfunktion durch Angabe eines Gewichtsvektors anzupassen, sodass Kategorien mit einer geringeren Anzahl von Stichproben größere Gewichte haben.
Hier ist ein Codebeispiel unter Verwendung des PyTorch-Frameworks, das zeigt, wie die Gewichtsanpassungsmethode zur Lösung des Klassenungleichgewichtsproblems verwendet wird:
import torch import torch.nn as nn import torch.optim as optim # 定义分类网络 class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(784, 100) self.fc2 = nn.Linear(100, 10) def forward(self, x): x = x.view(-1, 784) x = self.fc1(x) x = self.fc2(x) return x # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss(weight=torch.tensor([0.1, 0.9])) # 根据样本数量设置权重 optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 训练模型 for epoch in range(10): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = net(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() if i % 2000 == 1999: print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 2000)) running_loss = 0.0 print('Finished Training')
Im obigen Code werden die Gewichte zweier Klassen durch torch.tensor([0.1, 0.9])
angegeben, die Klasse mit der kleineren Anzahl der Proben Das Gewicht beträgt 0,1 und das Gewicht der Kategorien mit einer größeren Anzahl von Proben beträgt 0,9. Dadurch kann das Modell Kategorien mit einer geringeren Anzahl von Stichproben mehr Aufmerksamkeit schenken.
Referenzen:
[1] He, H. & Garcia, E. A. (2009). IEEE Transactions on Knowledge and Data Engineering, 21(9), 1263-1284.
[2] Chawla , N. V., Bowyer, K. W., Hall, L. O. & Kegelmeyer, W. P. (2002: Synthetic Minority Over-Sampling Technique, 16, 321-357.
).Das obige ist der detaillierte Inhalt vonKlassenungleichgewichtsproblem bei der Bildklassifizierung. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!