搜索
首页web前端js教程怎样用 TensorFlow.js 创建基本的 AI 模型?

怎样用 TensorFlow.js 创建基本的 AI 模型?

Nov 10, 2020 pm 05:54 PM
javascripttensorflow前端

怎样用 TensorFlow.js 创建基本的 AI 模型?

在本文中我们来研究怎样用 TensorFlow.js 创建基本的 AI 模型,并使用更复杂的模型实现一些有趣的功能。我只是刚刚开始接触人工智能,尽管不需要深入的人工智能知识,但还是需要搞清楚一些概念才行。

什么是模型?

真实世界是很复杂的,我们需要对其进行简化才能理解,可以用通过模型来进行简化,这种模型有很多种:比如世界地图,或者图表等。

1.jpg

比如要建立一个用来表示房子出租价格与房屋面积关系的模型:首先要收集一些数据:

房间数量 价格
3 131000
3 125000
4 235000
4 265000
5 535000

然后,把这些数据显示在二维图形上,把每个参数(价格,房间数量)都做为 1 个维度:

2.gif

然后我们可以画一条线,并预测 更多房间的房屋出租价格。这种模型被称为线性回归,它是机器学习中最简单的模型之一。不过这个模型还不够好:

  1. 只有 5 个数据,所以不够可靠。
  2. 只有 2 个参数(价格,房间),但是还有更多可能会影响价格的因素:比如地区、装修情况等。

可以通过添加更多的数据来解决第一个问题,比如一百万个。对于第二个问题,可以添加更多维度。在二维图表中可以很容易理解数据并画一条线,在三维图中可以使用平面:

3.jpeg

但是当数据的维度是三维呢四维甚至是 1000000 维的时候,大脑就没有办法在图表上对其进行可视化了,但是可以在维度超过三维时通过数学来计算超平面,而神经网络就是为了解决这个问题而生的。

什么是神经网络?

要解什么是神经网络,需要知道什么是神经元。真正的神经元看上去是这样的:

4.gif

神经元由以下几部分组成:

  • 树突:这是数据的输入端。
  • 轴突:这是输出端。
  • 突触(未在图中表示):该结构允许一个神经元与另一个神经元之间进行通信。它负责在轴突的神经末梢和附近神经元的树突之间传递电信号。这些突触是学习的关键,因为它们会根据用途增减电活动。

机器学习中的神经元(简化):

5.jpg

  • Inputs(输入) :输入的参数。
  • Weights(权重) :像突触一样,用来通过调节神经元更好的建立线性回归。
  • Linear function(线性函数) :每个神经元就像一个线性回归函数,对于线性回归模型,只需要一个神经元够了。
  • Activation function(激活函数) :可以用一些激活函数来将输出从标量改为另一个非线性函数。常见的有 sigmoid、RELU 和 tanh。
  • Output(输出) :应用激活函数后的计算输出。

激活函数是非常有用的,神经网络的强大主要归功于它。假如没有任何激活功能,就不可能得到智能的神经元网络。因为尽管你的神经网络中有多个神经元,但神经网络的输出始终将是线性回归。所以需要一些机制来将各个线性回归变形为非线性的来解决非线性问题。通过激活函数可以将这些线性函数转换为非线性函数:

6.jpg

训练模型

正如 2D 线性回归的例子所描述的,只需要在图中画一条线就可以预测新数据了。尽管如此,“深度学习”的思想是让我们的神经网络学会画这条线。对于一条简单的线,可以用只有一个神经元的非常简单的神经网络即可,但是对于想要做更复杂事情的模型,例如对两组数据进行分类这种操作,需要通过“训练”使网络学习怎样得到下面的内容:

7.png

这个过程并不复杂,因为它是二维的。每个模型都用来描述一个世界,但是“训练”的概念在所有模型中都非常相似。第一步是绘制一条随机线,并在算法中通过迭代对其进行改进,每次迭代中过程中修正错误。这种优化算法名为 Gradient Descent(梯度下降)(有着相同概念的算法还有更复杂的 SGD 或 ADAM 等)。每种算法(线性回归,对数回归等)都有不同的成本函数来度量误差,成本函数会始终收敛于某个点。它可以是凸函数或凹函数,但是最终要收敛在 0% 误差的点上。我们的目标就是实现这一点。

8.png

当使用梯度下降算法时,先从其成本函数的某个随机点开始,但是我们不知道它究竟在什么地方!这就像你被蒙着眼睛丢在一座山上,想要下山的话必须一步一步地走到最低点。如果地形是不规则的(例如凹函数),则下降会更加复杂。

在这里不会深入解释“梯度下降”算法,只需要记住这是训练 AI 模型过程中最小化预测误差的优化算法就足够了。这种算法需要大量的时间和 GPU 进行矩阵乘法。通常在第一次执行时很难达到这个收敛点,因此需要修正一些超参数,例如学习率(learning rate)或添加正则化(regularization)。在梯度下降迭代之后,当误差接近 0% 时,会接近收敛点。这样就创建了模型,接下来就能够进行预测了。

9.gif

用 TensorFlow.js 训练模型

TensorFlow.js 提供了一种创建神经网络的简便方法。首先用 trainModel 方法创建一个 LinearModel 类。我们将使用顺序模型。顺序模型是其中一层的输出是下一层的输入的模型,即当模型拓扑是简单的层级结构,没有分支或跳过。在 trainModel 方法内部定义层(我们仅使用一层,因为它足以解决线性回归问题):

import * as tf from '@tensorflow/tfjs';

/**
* 线性模型类
*/
export default class LinearModel {
  /**
 * 训练模型
 */
  async trainModel(xs, ys){
    const layers = tf.layers.dense({
      units: 1, // 输出空间的纬度
      inputShape: [1], // 只有一个参数
    });
    const lossAndOptimizer = {
      loss: 'meanSquaredError',
      optimizer: 'sgd', // 随机梯度下降
    };

    this.linearModel = tf.sequential();
    this.linearModel.add(layers); // 添加一层
    this.linearModel.compile(lossAndOptimizer);

    // 开始模型训练
    await this.linearModel.fit(
      tf.tensor1d(xs),
      tf.tensor1d(ys),
    );
  }

  //...
}

使用这个类进行训练:

const model = new LinearModel()

// xs 与 ys 是 数组成员(x-axis 与 y-axis)
await model.trainModel(xs, ys)

训练结束后就可以开始预测了。

用 TensorFlow.js 进行预测

尽管在训练模型时需要事先定义一些超参数,但是进行一般的预测还是很容易的。通过下面的代码就够了:

import * as tf from '@tensorflow/tfjs';

export default class LinearModel {
  ... //前面训练模型的代码

  predict(value){
    return Array.from(
      this.linearModel
      .predict(tf.tensor2d([value], [1, 1]))
      .dataSync()
    )
  }
}

现在就可以预测了:

const prediction = model.predict(500) // 预测数字 500
console.log(prediction) // => 420.423

10.gif

在 TensorFlow.js 中使用预训练的模型

训练模型是最难的部分。首先对数据进行标准化来进行训练,还需要正确的设定所有超参数等等。对于咱们初学者,可以直接用那些预先训练好的模型。 TensorFlow.js 可以使用很多预训练的模型,还可以导入使用 TensorFlow 或 Keras 创建的外部模型。例如可以直接用 posenet 模型(实时人体姿态评估)做一些有意思的项目:

11.gif

这个 Demo 的代码:https://github.com/aralroca/posenet-d3

它用起来很容易:

import * as posenet from '@tensorflow-models/posenet'

// 设置一些常数
const imageScaleFactor = 0.5
const outputStride = 16
const flipHorizontal = true
const weight = 0.5

// 加载模型
const net = await posenet.load(weight)

// 进行预测
const poses = await net.estimateSinglePose(
  imageElement,
  imageScaleFactor,
  flipHorizontal,
  outputStride
)

这个 JSON 是 pose 变量:

{
  "score": 0.32371445304906,
  "keypoints": [
    {
      "position": {
        "y": 76.291801452637,
        "x": 253.36747741699
      },
      "part": "nose",
      "score": 0.99539834260941
    },
    {
      "position": {
        "y": 71.10383605957,
        "x": 253.54365539551
      },
      "part": "leftEye",
      "score": 0.98781454563141
    }
    // 后面还有: rightEye, leftEar, rightEar, leftShoulder, rightShoulder
    // leftElbow, rightElbow, leftWrist, rightWrist, leftHip, rightHip,
    // leftKnee, rightKnee, leftAnkle, rightAnkle...
  ]
}

从官方的 demo 可以看得到,用这个模型可以开发出很多有趣的项目。

怎样用 TensorFlow.js 创建基本的 AI 模型?

这个项目的源代码:https://github.com/aralroca/fishFollow-posenet-tfjs

导入 Keras 模型

可以把外部模型导入 TensorFlow.js。下面是一个用 Keras 模型(h5格式)进行数字识别的程序。首先要用 tfjs_converter 对模型的格式进行转换。

pip install tensorflowjs

使用转换器:

tensorflowjs_converter --input_format keras keras/cnn.h5 src/assets

最后,把模型导入到 JS 代码中:

// 载入模型
const model = await tf.loadModel('./assets/model.json')

// 准备图片
let img = tf.fromPixels(imageData, 1)
img = img.reshape([1, 28, 28, 1])
img = tf.cast(img, 'float32')

// 进行预测
const output = model.predict(img)

只需要几行代码行就完成了。当然还可以在代码中添加更多的逻辑来实现更多功能,例如可以把数字写在 canvas 上,然后得到其图像来进行预测。

13.gif

这个项目的源代码:https://github.com/aralroca/MNIST_React_TensorFlowJS

为什么要用在浏览器中?

由于设备的不同,在浏览器中训练模型时,效率可能很低。用 TensorFlow.js 利用 WebGL 在后台训练模型,比用 Python 版的 TensorFlow 慢 1.5 ~ 2倍。

但是在 TensorFlow.js 出现之前,没有能直接在浏览器中使用机器学习模型的 API,现在则可以在浏览器应用中离线训练和使用模型。而且预测速度更快,因为不需要向服务器发送请求。另一个好处是成本低,因为所有这些计算都是在客户端完成的。

总结

  • 模型是表示现实世界的一种简化方式,可以使用它来进行预测。
  • 可以用神经网络创建模型。
  • TensorFlow.js 是创建神经网络的简便工具。

英文原文地址:https://aralroca.com/blog/first-steps-with-tensorflowjs

作者:Aral Roca

更多编程相关知识,请访问:编程课程!!

以上是怎样用 TensorFlow.js 创建基本的 AI 模型?的详细内容。更多信息请关注PHP中文网其他相关文章!

声明
本文转载于:segmentfault。如有侵权,请联系admin@php.cn删除
从C/C到JavaScript:所有工作方式从C/C到JavaScript:所有工作方式Apr 14, 2025 am 12:05 AM

从C/C 转向JavaScript需要适应动态类型、垃圾回收和异步编程等特点。1)C/C 是静态类型语言,需手动管理内存,而JavaScript是动态类型,垃圾回收自动处理。2)C/C 需编译成机器码,JavaScript则为解释型语言。3)JavaScript引入闭包、原型链和Promise等概念,增强了灵活性和异步编程能力。

JavaScript引擎:比较实施JavaScript引擎:比较实施Apr 13, 2025 am 12:05 AM

不同JavaScript引擎在解析和执行JavaScript代码时,效果会有所不同,因为每个引擎的实现原理和优化策略各有差异。1.词法分析:将源码转换为词法单元。2.语法分析:生成抽象语法树。3.优化和编译:通过JIT编译器生成机器码。4.执行:运行机器码。V8引擎通过即时编译和隐藏类优化,SpiderMonkey使用类型推断系统,导致在相同代码上的性能表现不同。

超越浏览器:现实世界中的JavaScript超越浏览器:现实世界中的JavaScriptApr 12, 2025 am 12:06 AM

JavaScript在现实世界中的应用包括服务器端编程、移动应用开发和物联网控制:1.通过Node.js实现服务器端编程,适用于高并发请求处理。2.通过ReactNative进行移动应用开发,支持跨平台部署。3.通过Johnny-Five库用于物联网设备控制,适用于硬件交互。

使用Next.js(后端集成)构建多租户SaaS应用程序使用Next.js(后端集成)构建多租户SaaS应用程序Apr 11, 2025 am 08:23 AM

我使用您的日常技术工具构建了功能性的多租户SaaS应用程序(一个Edtech应用程序),您可以做同样的事情。 首先,什么是多租户SaaS应用程序? 多租户SaaS应用程序可让您从唱歌中为多个客户提供服务

如何使用Next.js(前端集成)构建多租户SaaS应用程序如何使用Next.js(前端集成)构建多租户SaaS应用程序Apr 11, 2025 am 08:22 AM

本文展示了与许可证确保的后端的前端集成,并使用Next.js构建功能性Edtech SaaS应用程序。 前端获取用户权限以控制UI的可见性并确保API要求遵守角色库

JavaScript:探索网络语言的多功能性JavaScript:探索网络语言的多功能性Apr 11, 2025 am 12:01 AM

JavaScript是现代Web开发的核心语言,因其多样性和灵活性而广泛应用。1)前端开发:通过DOM操作和现代框架(如React、Vue.js、Angular)构建动态网页和单页面应用。2)服务器端开发:Node.js利用非阻塞I/O模型处理高并发和实时应用。3)移动和桌面应用开发:通过ReactNative和Electron实现跨平台开发,提高开发效率。

JavaScript的演变:当前的趋势和未来前景JavaScript的演变:当前的趋势和未来前景Apr 10, 2025 am 09:33 AM

JavaScript的最新趋势包括TypeScript的崛起、现代框架和库的流行以及WebAssembly的应用。未来前景涵盖更强大的类型系统、服务器端JavaScript的发展、人工智能和机器学习的扩展以及物联网和边缘计算的潜力。

神秘的JavaScript:它的作用以及为什么重要神秘的JavaScript:它的作用以及为什么重要Apr 09, 2025 am 12:07 AM

JavaScript是现代Web开发的基石,它的主要功能包括事件驱动编程、动态内容生成和异步编程。1)事件驱动编程允许网页根据用户操作动态变化。2)动态内容生成使得页面内容可以根据条件调整。3)异步编程确保用户界面不被阻塞。JavaScript广泛应用于网页交互、单页面应用和服务器端开发,极大地提升了用户体验和跨平台开发的灵活性。

See all articles

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

AI Hentai Generator

AI Hentai Generator

免费生成ai无尽的。

热门文章

R.E.P.O.能量晶体解释及其做什么(黄色晶体)
3 周前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳图形设置
3 周前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您听不到任何人,如何修复音频
3 周前By尊渡假赌尊渡假赌尊渡假赌
WWE 2K25:如何解锁Myrise中的所有内容
1 个月前By尊渡假赌尊渡假赌尊渡假赌

热工具

SublimeText3 Linux新版

SublimeText3 Linux新版

SublimeText3 Linux最新版

适用于 Eclipse 的 SAP NetWeaver 服务器适配器

适用于 Eclipse 的 SAP NetWeaver 服务器适配器

将Eclipse与SAP NetWeaver应用服务器集成。

SublimeText3汉化版

SublimeText3汉化版

中文版,非常好用

Dreamweaver Mac版

Dreamweaver Mac版

视觉化网页开发工具

DVWA

DVWA

Damn Vulnerable Web App (DVWA) 是一个PHP/MySQL的Web应用程序,非常容易受到攻击。它的主要目标是成为安全专业人员在合法环境中测试自己的技能和工具的辅助工具,帮助Web开发人员更好地理解保护Web应用程序的过程,并帮助教师/学生在课堂环境中教授/学习Web应用程序安全。DVWA的目标是通过简单直接的界面练习一些最常见的Web漏洞,难度各不相同。请注意,该软件中