搜索
首页后端开发Python教程为什么在小数据集上微调 MLP 模型,仍然保持与预训练权重相同的测试精度?

为什么在小数据集上微调 MLP 模型,仍然保持与预训练权重相同的测试精度?

问题内容

我设计了一个简单的 mlp 模型,在 6k 数据样本上进行训练。

class mlp(nn.module):
    def __init__(self,input_dim=92, hidden_dim = 150, num_classes=2):
        super().__init__()
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        #self.softmax = nn.softmax(dim=1)

        self.layers = nn.sequential(
            nn.linear(self.input_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.hidden_dim),
            nn.relu(),
            nn.linear(self.hidden_dim, self.num_classes),

        )

    def forward(self, x):
        x = self.layers(x)
        return x

并且模型已实例化

model = mlp(input_dim=input_dim, hidden_dim=hidden_dim, num_classes=num_classes).to(device)

optimizer = optimizer.adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
criterion = nn.crossentropyloss()

和超参数:

num_epoch = 300   # 200e3//len(train_loader)
learning_rate = 1e-3
batch_size = 64
device = torch.device("cuda")
seed = 42
torch.manual_seed(42)

我的实现主要遵循这个问题。我将模型保存为预训练权重 model_weights.pth

model在测试数据集上的准确率是96.80%

然后,我还有另外 50 个样本(在 finetune_loader 中),我正在尝试在这 50 个样本上微调模型:

model_finetune = MLP()
model_finetune.load_state_dict(torch.load('model_weights.pth'))
model_finetune.to(device)
model_finetune.train()
# train the network
for t in tqdm(range(num_epoch)):
  for i, data in enumerate(finetune_loader, 0):
    #def closure():
      # Get and prepare inputs
      inputs, targets = data
      inputs, targets = inputs.float(), targets.long()
      inputs, targets = inputs.to(device), targets.to(device)
      
      # Zero the gradients
      optimizer.zero_grad()
      # Perform forward pass
      outputs = model_finetune(inputs)
      # Compute loss
      loss = criterion(outputs, targets)
      # Perform backward pass
      loss.backward()
      #return loss
      optimizer.step()     # a

model_finetune.eval()
with torch.no_grad():
    outputs2 = model_finetune(test_data)
    #predicted_labels = outputs.squeeze().tolist()

    _, preds = torch.max(outputs2, 1)
    prediction_test = np.array(preds.cpu())
    accuracy_test_finetune = accuracy_score(y_test, prediction_test)
    accuracy_test_finetune
    
    Output: 0.9680851063829787

我检查过,精度与将模型微调到 50 个样本之前保持不变,并且输出概率也相同。

可能是什么原因?我在微调代码中是否犯了一些错误?


正确答案


您必须使用新模型(model_finetune 对象)重新初始化优化器。目前,正如我在您的代码中看到的那样,它似乎仍然使用使用旧模型权重初始化的优化器 - model.parameters()。

以上是为什么在小数据集上微调 MLP 模型,仍然保持与预训练权重相同的测试精度?的详细内容。更多信息请关注PHP中文网其他相关文章!

声明
本文转载于:stackoverflow。如有侵权,请联系admin@php.cn删除
Python:游戏,Guis等Python:游戏,Guis等Apr 13, 2025 am 12:14 AM

Python在游戏和GUI开发中表现出色。1)游戏开发使用Pygame,提供绘图、音频等功能,适合创建2D游戏。2)GUI开发可选择Tkinter或PyQt,Tkinter简单易用,PyQt功能丰富,适合专业开发。

Python vs.C:申请和用例Python vs.C:申请和用例Apr 12, 2025 am 12:01 AM

Python适合数据科学、Web开发和自动化任务,而C 适用于系统编程、游戏开发和嵌入式系统。 Python以简洁和强大的生态系统着称,C 则以高性能和底层控制能力闻名。

2小时的Python计划:一种现实的方法2小时的Python计划:一种现实的方法Apr 11, 2025 am 12:04 AM

2小时内可以学会Python的基本编程概念和技能。1.学习变量和数据类型,2.掌握控制流(条件语句和循环),3.理解函数的定义和使用,4.通过简单示例和代码片段快速上手Python编程。

Python:探索其主要应用程序Python:探索其主要应用程序Apr 10, 2025 am 09:41 AM

Python在web开发、数据科学、机器学习、自动化和脚本编写等领域有广泛应用。1)在web开发中,Django和Flask框架简化了开发过程。2)数据科学和机器学习领域,NumPy、Pandas、Scikit-learn和TensorFlow库提供了强大支持。3)自动化和脚本编写方面,Python适用于自动化测试和系统管理等任务。

您可以在2小时内学到多少python?您可以在2小时内学到多少python?Apr 09, 2025 pm 04:33 PM

两小时内可以学到Python的基础知识。1.学习变量和数据类型,2.掌握控制结构如if语句和循环,3.了解函数的定义和使用。这些将帮助你开始编写简单的Python程序。

如何在10小时内通过项目和问题驱动的方式教计算机小白编程基础?如何在10小时内通过项目和问题驱动的方式教计算机小白编程基础?Apr 02, 2025 am 07:18 AM

如何在10小时内教计算机小白编程基础?如果你只有10个小时来教计算机小白一些编程知识,你会选择教些什么�...

如何在使用 Fiddler Everywhere 进行中间人读取时避免被浏览器检测到?如何在使用 Fiddler Everywhere 进行中间人读取时避免被浏览器检测到?Apr 02, 2025 am 07:15 AM

使用FiddlerEverywhere进行中间人读取时如何避免被检测到当你使用FiddlerEverywhere...

Python 3.6加载Pickle文件报错"__builtin__"模块未找到怎么办?Python 3.6加载Pickle文件报错"__builtin__"模块未找到怎么办?Apr 02, 2025 am 07:12 AM

Python3.6环境下加载Pickle文件报错:ModuleNotFoundError:Nomodulenamed...

See all articles

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

AI Hentai Generator

AI Hentai Generator

免费生成ai无尽的。

热门文章

R.E.P.O.能量晶体解释及其做什么(黄色晶体)
3 周前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳图形设置
3 周前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您听不到任何人,如何修复音频
3 周前By尊渡假赌尊渡假赌尊渡假赌
WWE 2K25:如何解锁Myrise中的所有内容
4 周前By尊渡假赌尊渡假赌尊渡假赌

热工具

螳螂BT

螳螂BT

Mantis是一个易于部署的基于Web的缺陷跟踪工具,用于帮助产品缺陷跟踪。它需要PHP、MySQL和一个Web服务器。请查看我们的演示和托管服务。

记事本++7.3.1

记事本++7.3.1

好用且免费的代码编辑器

MinGW - 适用于 Windows 的极简 GNU

MinGW - 适用于 Windows 的极简 GNU

这个项目正在迁移到osdn.net/projects/mingw的过程中,你可以继续在那里关注我们。MinGW:GNU编译器集合(GCC)的本地Windows移植版本,可自由分发的导入库和用于构建本地Windows应用程序的头文件;包括对MSVC运行时的扩展,以支持C99功能。MinGW的所有软件都可以在64位Windows平台上运行。

PhpStorm Mac 版本

PhpStorm Mac 版本

最新(2018.2.1 )专业的PHP集成开发工具

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用