首頁  >  文章  >  web前端  >  怎麼用 TensorFlow.js 建立基本的 AI 模型?

怎麼用 TensorFlow.js 建立基本的 AI 模型?

青灯夜游
青灯夜游轉載
2020-11-10 17:54:492962瀏覽

怎麼用 TensorFlow.js 建立基本的 AI 模型?

在本文中我們來研究如何用 TensorFlow.js 建立基本的 AI 模型,並使用更複雜的模型實作一些有趣的功能。我只是剛開始接觸人工智慧,儘管不需要深入的人工智慧知識,但還是需要搞清楚一些概念才行。

什麼是模型?

真實世界是很複雜的,我們需要對其進行簡化才能理解,可以用透過模型來簡化,這種模型有很多種:例如世界地圖,或圖表等。

怎麼用 TensorFlow.js 建立基本的 AI 模型?

例如要建立一個用來表示房子出租價格與房屋面積關係的模型:首先要收集一些資料:

房間數 價格
3 131000
3 125000
4 #235000
4 #265000
5 535000
#

然後,把這些數據顯示在二維圖形上,把每個參數(價格,房間數量)都做為1 個維度:

怎麼用 TensorFlow.js 建立基本的 AI 模型?

##然後我們可以畫一條線,並預測更多房間的房屋出租價格。這種模型被稱為線性迴歸,它是機器學習中最簡單的模型之一。不過這個模型還不夠好:

    只有 5 個數據,所以不夠可靠。
  1. 只有 2 個參數(價格,房間),但還有更多可能會影響價格的因素:例如地區、裝潢情況等。
可以透過增加更多的資料來解決第一個問題,例如一百萬個。對於第二個問題,可以新增更多維度。在二維圖表中可以很容易理解資料並畫一條線,在三維圖中可以使用平面:

怎麼用 TensorFlow.js 建立基本的 AI 模型?

但是當資料的維度是三維呢四維甚至是1000000維的時候,大腦就沒有辦法在圖表上對其進行可視化了,但是可以在維度超過三維時通過數學來計算超平面,而神經網絡就是為了解決這個問題而生的。

什麼是神經網路?

要解什麼是神經網絡,需要知道什麼是神經元。真正的神經元看起來是這樣的:

怎麼用 TensorFlow.js 建立基本的 AI 模型?

神經元由以下幾部分組成:

    ##樹突
  • :這是資料的輸入端。
  • 軸突
  • :這是輸出端。
  • 突觸
  • (未在圖中表示):該結構允許一個神經元與另一個神經元之間進行通訊。它負責在軸突的神經末梢和附近神經元的樹突之間傳遞電訊號。這些突觸是學習的關鍵,因為它們會根據用途增減電活動。
  • 機器學習中的神經元(簡化):

怎麼用 TensorFlow.js 建立基本的 AI 模型?

    #Inputs(輸入)
  • :輸入的參數。
  • Weights(權重)
  • :像突觸一樣,用來透過調節神經元更好的建立線性迴歸。
  • Linear function(線性函數)
  • :每個神經元就像線性迴歸函數,對於線性迴歸模型,只需要一個神經元夠了。
  • Activation function(激活函數)
  • :可以用一些激活函數來將輸出從標量改為另一個非線性函數。常見的有 sigmoid、RELU 和 tanh。
  • Output(輸出)
  • :套用激活函數後的計算輸出。
  • 激活函數是非常有用的,神經網路的強大主要歸功於它。要沒有任何活化功能,就不可能得到智慧的神經元網路。因為儘管你的神經網路中有多個神經元,但神經網路的輸出始終將是線性迴歸。所以需要一些機制來將各個線性迴歸變形為非線性的來解決非線性問題。透過激活函數可以將這些線性函數轉換為非線性函數:

怎麼用 TensorFlow.js 建立基本的 AI 模型?訓練模型

正如2D 線性迴歸的例子所描述的,只需要在圖中畫一條線就可以預測新數據了。儘管如此,「深度學習」的想法是讓我們的神經網路學會畫這條線。對於一條簡單的線,可以用只有一個神經元的非常簡單的神經網絡即可,但是對於想要做更複雜事情的模型,例如對兩組數據進行分類這種操作,需要通過“訓練”使網絡學習怎麼得到下面的內容:

怎麼用 TensorFlow.js 建立基本的 AI 模型?這個過程並不複雜,因為它是二維的。每個模型都用來描述一個世界,但是「訓練」的概念在所有模型中都非常相似。第一步是繪製一條隨機線,並在演算法中透過迭代對其進行改進,每次迭代中過程中修正錯誤。這種最佳化演算法名為 Gradient Descent(梯度下降)(有著相同概念的演算法還有更複雜的 SGD 或 ADAM 等)。每種演算法(線性迴歸,對數迴歸等)都有不同的成本函數來度量誤差,成本函數會始終收斂於某一點。它可以是凸函數或凹函數,但是最終要收斂在 0% 誤差的點上。我們的目標就是實現這一點。

當使用梯度下降演算法時,先從其成本函數的某個隨機點開始,但是我們不知道它究竟在什麼地方!這就像你被蒙著眼睛丟在一座山上,想要下山的話必須一步一步走到最低點。如果地形是不規則的(例如凹函數),則下降會更加複雜。

在這裡不會深入解釋「梯度下降」演算法,只需要記住這是訓練 AI 模型過程中最小化預測誤差的最佳化演算法就足夠了。這種演算法需要大量的時間和 GPU 進行矩陣乘法。通常在第一次執行時很難達到這個收斂點,因此需要修正一些超參數,例如學習率(learning rate)或添加正規化(regularization)。在梯度下降迭代之後,當誤差接近 0% 時,會接近收斂點。這樣就創建了模型,接下來就能夠進行預測了。

怎麼用 TensorFlow.js 建立基本的 AI 模型?

用 TensorFlow.js 訓練模型

TensorFlow.js 提供了一個建立神經網路的簡單方法。首先用 trainModel 方法建立一個 LinearModel 類別。我們將使用順序模型。順序模型是其中一層的輸出是下一層的輸入的模型,即當模型拓撲是簡單的層級結構,沒有分支或跳過。在trainModel 方法內部定義層(我們只使用一層,因為它足以解決線性迴歸問題):

import * as tf from '@tensorflow/tfjs';

/**
* 线性模型类
*/
export default class LinearModel {
  /**
 * 训练模型
 */
  async trainModel(xs, ys){
    const layers = tf.layers.dense({
      units: 1, // 输出空间的纬度
      inputShape: [1], // 只有一个参数
    });
    const lossAndOptimizer = {
      loss: 'meanSquaredError',
      optimizer: 'sgd', // 随机梯度下降
    };

    this.linearModel = tf.sequential();
    this.linearModel.add(layers); // 添加一层
    this.linearModel.compile(lossAndOptimizer);

    // 开始模型训练
    await this.linearModel.fit(
      tf.tensor1d(xs),
      tf.tensor1d(ys),
    );
  }

  //...
}

使用這個類別進行訓練:

const model = new LinearModel()

// xs 与 ys 是 数组成员(x-axis 与 y-axis)
await model.trainModel(xs, ys)

訓練結束後就可以開始預測了。

用 TensorFlow.js 進行預測

儘管在訓練模型時需要事先定義一些超參數,但是進行一般的預測還是很容易的。透過下面的程式碼就夠了:

import * as tf from '@tensorflow/tfjs';

export default class LinearModel {
  ... //前面训练模型的代码

  predict(value){
    return Array.from(
      this.linearModel
      .predict(tf.tensor2d([value], [1, 1]))
      .dataSync()
    )
  }
}

現在就可以預測了:

const prediction = model.predict(500) // 预测数字 500
console.log(prediction) // => 420.423

怎麼用 TensorFlow.js 建立基本的 AI 模型?

在TensorFlow.js 中使用預先訓練的模型

訓練模型是最困難的部分。首先將資料標準化來進行訓練,還需要正確的設定所有超參數等等。對於咱們初學者,可以直接用那些預先訓練好的模型。 TensorFlow.js 可以使用許多預先訓練的模型,也可以匯入使用 TensorFlow 或 Keras 建立的外部模型。例如可以直接用posenet 模型(即時人體姿態評估)做一些有趣的專案:

怎麼用 TensorFlow.js 建立基本的 AI 模型?

這個Demo 的程式碼:https://github.com/aralroca/posenet- d3

它用起來很容易:

import * as posenet from '@tensorflow-models/posenet'

// 设置一些常数
const imageScaleFactor = 0.5
const outputStride = 16
const flipHorizontal = true
const weight = 0.5

// 加载模型
const net = await posenet.load(weight)

// 进行预测
const poses = await net.estimateSinglePose(
  imageElement,
  imageScaleFactor,
  flipHorizontal,
  outputStride
)

這個JSON 是pose 變數:

{
  "score": 0.32371445304906,
  "keypoints": [
    {
      "position": {
        "y": 76.291801452637,
        "x": 253.36747741699
      },
      "part": "nose",
      "score": 0.99539834260941
    },
    {
      "position": {
        "y": 71.10383605957,
        "x": 253.54365539551
      },
      "part": "leftEye",
      "score": 0.98781454563141
    }
    // 后面还有: rightEye, leftEar, rightEar, leftShoulder, rightShoulder
    // leftElbow, rightElbow, leftWrist, rightWrist, leftHip, rightHip,
    // leftKnee, rightKnee, leftAnkle, rightAnkle...
  ]
}

從官方的demo 可以看得到,用這個模型可以開發出很多有趣的項目。

怎麼用 TensorFlow.js 建立基本的 AI 模型?

這個專案的原始碼:https://github.com/aralroca/fishFollow-posenet-tfjs

匯入Keras 模型

可以把外部模型導入TensorFlow.js。下面是一個用 Keras 模型(h5格式)進行數字辨識的程式。首先要用 tfjs_converter 對模型的格式進行轉換。

pip install tensorflowjs

使用轉換器:

tensorflowjs_converter --input_format keras keras/cnn.h5 src/assets

最後,把模型匯入到 JS 程式碼:

// 载入模型
const model = await tf.loadModel('./assets/model.json')

// 准备图片
let img = tf.fromPixels(imageData, 1)
img = img.reshape([1, 28, 28, 1])
img = tf.cast(img, 'float32')

// 进行预测
const output = model.predict(img)

只需要幾行程式碼行就完成了。當然還可以在程式碼中加入更多的邏輯來實現更多功能,例如可以把數字寫在 canvas 上,然後得到其圖像來進行預測。

怎麼用 TensorFlow.js 建立基本的 AI 模型?

這個專案的原始碼:https://github.com/aralroca/MNIST_React_TensorFlowJS

為什麼要用在瀏覽器裡?

由於裝置的不同,在瀏覽器中訓練模型時,效率可能很低。用 TensorFlow.js 利用 WebGL 在後台訓練模型,比用 Python 版的 TensorFlow 慢 1.5 ~ 2倍。

但是在 TensorFlow.js 出現之前,沒有能直接在瀏覽器中使用機器學習模型的 API,現在則可以在瀏覽器應用程式中離線訓練和使用模型。而且預測速度更快,因為不需要向伺服器發送請求。另一個好處是成本低,因為所有這些計算都是在客戶端完成的。

總結

  • 模型是表示現實世界的簡化方式,可以用它來進行預測。
  • 可以用神經網路建立模型。
  • TensorFlow.js 是建立神經網路的簡單工具。

英文原文網址:https://aralroca.com/blog/first-steps-with-tensorflowjs

作者:Aral Roca

#更多程式相關知識,請造訪:程式設計課程! !

以上是怎麼用 TensorFlow.js 建立基本的 AI 模型?的詳細內容。更多資訊請關注PHP中文網其他相關文章!

陳述:
本文轉載於:segmentfault.com。如有侵權,請聯絡admin@php.cn刪除