Rumah  >  Artikel  >  hujung hadapan web  >  怎样用 TensorFlow.js 创建基本的 AI 模型?

怎样用 TensorFlow.js 创建基本的 AI 模型?

青灯夜游
青灯夜游ke hadapan
2020-11-10 17:54:492962semak imbas

怎样用 TensorFlow.js 创建基本的 AI 模型?

在本文中我们来研究怎样用 TensorFlow.js 创建基本的 AI 模型,并使用更复杂的模型实现一些有趣的功能。我只是刚刚开始接触人工智能,尽管不需要深入的人工智能知识,但还是需要搞清楚一些概念才行。

什么是模型?

真实世界是很复杂的,我们需要对其进行简化才能理解,可以用通过模型来进行简化,这种模型有很多种:比如世界地图,或者图表等。

1.jpg

比如要建立一个用来表示房子出租价格与房屋面积关系的模型:首先要收集一些数据:

房间数量 价格
3 131000
3 125000
4 235000
4 265000
5 535000

然后,把这些数据显示在二维图形上,把每个参数(价格,房间数量)都做为 1 个维度:

2.gif

然后我们可以画一条线,并预测 更多房间的房屋出租价格。这种模型被称为线性回归,它是机器学习中最简单的模型之一。不过这个模型还不够好:

  1. 只有 5 个数据,所以不够可靠。
  2. 只有 2 个参数(价格,房间),但是还有更多可能会影响价格的因素:比如地区、装修情况等。

可以通过添加更多的数据来解决第一个问题,比如一百万个。对于第二个问题,可以添加更多维度。在二维图表中可以很容易理解数据并画一条线,在三维图中可以使用平面:

3.jpeg

但是当数据的维度是三维呢四维甚至是 1000000 维的时候,大脑就没有办法在图表上对其进行可视化了,但是可以在维度超过三维时通过数学来计算超平面,而神经网络就是为了解决这个问题而生的。

什么是神经网络?

要解什么是神经网络,需要知道什么是神经元。真正的神经元看上去是这样的:

4.gif

神经元由以下几部分组成:

  • 树突:这是数据的输入端。
  • 轴突:这是输出端。
  • 突触(未在图中表示):该结构允许一个神经元与另一个神经元之间进行通信。它负责在轴突的神经末梢和附近神经元的树突之间传递电信号。这些突触是学习的关键,因为它们会根据用途增减电活动。

机器学习中的神经元(简化):

5.jpg

  • Inputs(输入) :输入的参数。
  • Weights(权重) :像突触一样,用来通过调节神经元更好的建立线性回归。
  • Linear function(线性函数) :每个神经元就像一个线性回归函数,对于线性回归模型,只需要一个神经元够了。
  • Activation function(激活函数) :可以用一些激活函数来将输出从标量改为另一个非线性函数。常见的有 sigmoid、RELU 和 tanh。
  • Output(输出) :应用激活函数后的计算输出。

激活函数是非常有用的,神经网络的强大主要归功于它。假如没有任何激活功能,就不可能得到智能的神经元网络。因为尽管你的神经网络中有多个神经元,但神经网络的输出始终将是线性回归。所以需要一些机制来将各个线性回归变形为非线性的来解决非线性问题。通过激活函数可以将这些线性函数转换为非线性函数:

6.jpg

训练模型

正如 2D 线性回归的例子所描述的,只需要在图中画一条线就可以预测新数据了。尽管如此,“深度学习”的思想是让我们的神经网络学会画这条线。对于一条简单的线,可以用只有一个神经元的非常简单的神经网络即可,但是对于想要做更复杂事情的模型,例如对两组数据进行分类这种操作,需要通过“训练”使网络学习怎样得到下面的内容:

7.png

这个过程并不复杂,因为它是二维的。每个模型都用来描述一个世界,但是“训练”的概念在所有模型中都非常相似。第一步是绘制一条随机线,并在算法中通过迭代对其进行改进,每次迭代中过程中修正错误。这种优化算法名为 Gradient Descent(梯度下降)(有着相同概念的算法还有更复杂的 SGD 或 ADAM 等)。每种算法(线性回归,对数回归等)都有不同的成本函数来度量误差,成本函数会始终收敛于某个点。它可以是凸函数或凹函数,但是最终要收敛在 0% 误差的点上。我们的目标就是实现这一点。

8.png

当使用梯度下降算法时,先从其成本函数的某个随机点开始,但是我们不知道它究竟在什么地方!这就像你被蒙着眼睛丢在一座山上,想要下山的话必须一步一步地走到最低点。如果地形是不规则的(例如凹函数),则下降会更加复杂。

在这里不会深入解释“梯度下降”算法,只需要记住这是训练 AI 模型过程中最小化预测误差的优化算法就足够了。这种算法需要大量的时间和 GPU 进行矩阵乘法。通常在第一次执行时很难达到这个收敛点,因此需要修正一些超参数,例如学习率(learning rate)或添加正则化(regularization)。在梯度下降迭代之后,当误差接近 0% 时,会接近收敛点。这样就创建了模型,接下来就能够进行预测了。

9.gif

用 TensorFlow.js 训练模型

TensorFlow.js 提供了一种创建神经网络的简便方法。首先用 trainModel 方法创建一个 LinearModel 类。我们将使用顺序模型。顺序模型是其中一层的输出是下一层的输入的模型,即当模型拓扑是简单的层级结构,没有分支或跳过。在 trainModel 方法内部定义层(我们仅使用一层,因为它足以解决线性回归问题):

import * as tf from '@tensorflow/tfjs';

/**
* 线性模型类
*/
export default class LinearModel {
  /**
 * 训练模型
 */
  async trainModel(xs, ys){
    const layers = tf.layers.dense({
      units: 1, // 输出空间的纬度
      inputShape: [1], // 只有一个参数
    });
    const lossAndOptimizer = {
      loss: 'meanSquaredError',
      optimizer: 'sgd', // 随机梯度下降
    };

    this.linearModel = tf.sequential();
    this.linearModel.add(layers); // 添加一层
    this.linearModel.compile(lossAndOptimizer);

    // 开始模型训练
    await this.linearModel.fit(
      tf.tensor1d(xs),
      tf.tensor1d(ys),
    );
  }

  //...
}

使用这个类进行训练:

const model = new LinearModel()

// xs 与 ys 是 数组成员(x-axis 与 y-axis)
await model.trainModel(xs, ys)

训练结束后就可以开始预测了。

用 TensorFlow.js 进行预测

尽管在训练模型时需要事先定义一些超参数,但是进行一般的预测还是很容易的。通过下面的代码就够了:

import * as tf from '@tensorflow/tfjs';

export default class LinearModel {
  ... //前面训练模型的代码

  predict(value){
    return Array.from(
      this.linearModel
      .predict(tf.tensor2d([value], [1, 1]))
      .dataSync()
    )
  }
}

现在就可以预测了:

const prediction = model.predict(500) // 预测数字 500
console.log(prediction) // => 420.423

10.gif

在 TensorFlow.js 中使用预训练的模型

训练模型是最难的部分。首先对数据进行标准化来进行训练,还需要正确的设定所有超参数等等。对于咱们初学者,可以直接用那些预先训练好的模型。 TensorFlow.js 可以使用很多预训练的模型,还可以导入使用 TensorFlow 或 Keras 创建的外部模型。例如可以直接用 posenet 模型(实时人体姿态评估)做一些有意思的项目:

11.gif

这个 Demo 的代码:https://github.com/aralroca/posenet-d3

它用起来很容易:

import * as posenet from '@tensorflow-models/posenet'

// 设置一些常数
const imageScaleFactor = 0.5
const outputStride = 16
const flipHorizontal = true
const weight = 0.5

// 加载模型
const net = await posenet.load(weight)

// 进行预测
const poses = await net.estimateSinglePose(
  imageElement,
  imageScaleFactor,
  flipHorizontal,
  outputStride
)

这个 JSON 是 pose 变量:

{
  "score": 0.32371445304906,
  "keypoints": [
    {
      "position": {
        "y": 76.291801452637,
        "x": 253.36747741699
      },
      "part": "nose",
      "score": 0.99539834260941
    },
    {
      "position": {
        "y": 71.10383605957,
        "x": 253.54365539551
      },
      "part": "leftEye",
      "score": 0.98781454563141
    }
    // 后面还有: rightEye, leftEar, rightEar, leftShoulder, rightShoulder
    // leftElbow, rightElbow, leftWrist, rightWrist, leftHip, rightHip,
    // leftKnee, rightKnee, leftAnkle, rightAnkle...
  ]
}

从官方的 demo 可以看得到,用这个模型可以开发出很多有趣的项目。

怎样用 TensorFlow.js 创建基本的 AI 模型?

这个项目的源代码:https://github.com/aralroca/fishFollow-posenet-tfjs

导入 Keras 模型

可以把外部模型导入 TensorFlow.js。下面是一个用 Keras 模型(h5格式)进行数字识别的程序。首先要用 tfjs_converter 对模型的格式进行转换。

pip install tensorflowjs

使用转换器:

tensorflowjs_converter --input_format keras keras/cnn.h5 src/assets

最后,把模型导入到 JS 代码中:

// 载入模型
const model = await tf.loadModel('./assets/model.json')

// 准备图片
let img = tf.fromPixels(imageData, 1)
img = img.reshape([1, 28, 28, 1])
img = tf.cast(img, 'float32')

// 进行预测
const output = model.predict(img)

只需要几行代码行就完成了。当然还可以在代码中添加更多的逻辑来实现更多功能,例如可以把数字写在 canvas 上,然后得到其图像来进行预测。

13.gif

这个项目的源代码:https://github.com/aralroca/MNIST_React_TensorFlowJS

为什么要用在浏览器中?

由于设备的不同,在浏览器中训练模型时,效率可能很低。用 TensorFlow.js 利用 WebGL 在后台训练模型,比用 Python 版的 TensorFlow 慢 1.5 ~ 2倍。

但是在 TensorFlow.js 出现之前,没有能直接在浏览器中使用机器学习模型的 API,现在则可以在浏览器应用中离线训练和使用模型。而且预测速度更快,因为不需要向服务器发送请求。另一个好处是成本低,因为所有这些计算都是在客户端完成的。

总结

  • 模型是表示现实世界的一种简化方式,可以使用它来进行预测。
  • 可以用神经网络创建模型。
  • TensorFlow.js 是创建神经网络的简便工具。

英文原文地址:https://aralroca.com/blog/first-steps-with-tensorflowjs

作者:Aral Roca

更多编程相关知识,请访问:编程课程!!

Atas ialah kandungan terperinci 怎样用 TensorFlow.js 创建基本的 AI 模型?. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Artikel ini dikembalikan pada:segmentfault.com. Jika ada pelanggaran, sila hubungi admin@php.cn Padam
Artikel sebelumnya:深入了解JavaScript中的nullArtikel seterusnya:jQuery和Zepto是什么?