찾다
백엔드 개발파이썬 튜토리얼TensorFlow를 사용하여 다중 클래스 지원 벡터 머신을 구현하기 위한 샘플 코드

이 기사에서는 TensorFlow를 사용하여 다중 클래스 지원 벡터 머신을 구현하기 위한 샘플 코드를 주로 소개합니다. 함께 구경해보세요

이 글에서는 붓꽃 데이터 세트를 학습하여 세 가지 종류의 꽃을 분류하는 다중 클래스 지원 벡터 머신 분류기를 자세히 보여드리겠습니다.

SVM 알고리즘은 원래 이진 분류 문제를 위해 설계되었지만 일부 전략을 통해 다중 클래스 분류에도 사용할 수 있습니다. 두 가지 주요 전략은 1대 전체(1대 전체) 접근 방식과 1대 1(1대 1) 접근 방식입니다.

일대일 방법은 두 가지 유형의 샘플 사이에 이진 분류기를 설계하고 생성한 후 가장 많은 표를 얻은 카테고리가 알려지지 않은 샘플의 예측 카테고리가 되는 것입니다. 하지만 카테고리가 많으면(k개 카테고리) k를 생성해야 합니다! /(k-2)! 2! 분류기의 경우 계산 비용은 여전히 ​​상당히 높습니다.

다중 클래스 분류자를 구현하는 또 다른 방법은 각 클래스에 대한 분류자를 생성하는 일대다입니다. 마지막 예측 클래스는 SVM 간격이 가장 큰 클래스입니다. 이 문서에서는 이 방법을 구현합니다.

붓꽃 데이터세트를 로드하고 가우시안 커널 기능을 갖춘 비선형 다중 클래스 SVM 모델을 사용하겠습니다. 붓꽃 데이터 세트에는 산 붓꽃, 변경 가능한 붓꽃, 버지니아 붓꽃(I.setosa, I.virginica 및 I.versicolor)의 세 가지 범주가 포함되어 있습니다. 예측을 위해 세 가지 가우스 커널 함수 SVM을 생성합니다.

# Multi-class (Nonlinear) SVM Example
#----------------------------------
#
# This function wll illustrate how to
# implement the gaussian kernel with
# multiple classes on the iris dataset.
#
# Gaussian Kernel:
# K(x1, x2) = exp(-gamma * abs(x1 - x2)^2)
#
# X : (Sepal Length, Petal Width)
# Y: (I. setosa, I. virginica, I. versicolor) (3 classes)
#
# Basic idea: introduce an extra dimension to do
# one vs all classification.
#
# The prediction of a point will be the category with
# the largest margin or distance to boundary.

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets
from tensorflow.python.framework import ops
ops.reset_default_graph()

# Create graph
sess = tf.Session()

# Load the data
# 加载iris数据集并为每类分离目标值。
# 因为我们想绘制结果图,所以只使用花萼长度和花瓣宽度两个特征。
# 为了便于绘图,也会分离x值和y值
# iris.data = [(Sepal Length, Sepal Width, Petal Length, Petal Width)]
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals1 = np.array([1 if y==0 else -1 for y in iris.target])
y_vals2 = np.array([1 if y==1 else -1 for y in iris.target])
y_vals3 = np.array([1 if y==2 else -1 for y in iris.target])
y_vals = np.array([y_vals1, y_vals2, y_vals3])
class1_x = [x[0] for i,x in enumerate(x_vals) if iris.target[i]==0]
class1_y = [x[1] for i,x in enumerate(x_vals) if iris.target[i]==0]
class2_x = [x[0] for i,x in enumerate(x_vals) if iris.target[i]==1]
class2_y = [x[1] for i,x in enumerate(x_vals) if iris.target[i]==1]
class3_x = [x[0] for i,x in enumerate(x_vals) if iris.target[i]==2]
class3_y = [x[1] for i,x in enumerate(x_vals) if iris.target[i]==2]

# Declare batch size
batch_size = 50

# Initialize placeholders
# 数据集的维度在变化,从单类目标分类到三类目标分类。
# 我们将利用矩阵传播和reshape技术一次性计算所有的三类SVM。
# 注意,由于一次性计算所有分类,
# y_target占位符的维度是[3,None],模型变量b初始化大小为[3,batch_size]
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[3, None], dtype=tf.float32)
prediction_grid = tf.placeholder(shape=[None, 2], dtype=tf.float32)

# Create variables for svm
b = tf.Variable(tf.random_normal(shape=[3,batch_size]))

# Gaussian (RBF) kernel 核函数只依赖x_data
gamma = tf.constant(-10.0)
dist = tf.reduce_sum(tf.square(x_data), 1)
dist = tf.reshape(dist, [-1,1])
sq_dists = tf.multiply(2., tf.matmul(x_data, tf.transpose(x_data)))
my_kernel = tf.exp(tf.multiply(gamma, tf.abs(sq_dists)))

# Declare function to do reshape/batch multiplication
# 最大的变化是批量矩阵乘法。
# 最终的结果是三维矩阵,并且需要传播矩阵乘法。
# 所以数据矩阵和目标矩阵需要预处理,比如xT·x操作需额外增加一个维度。
# 这里创建一个函数来扩展矩阵维度,然后进行矩阵转置,
# 接着调用TensorFlow的tf.batch_matmul()函数
def reshape_matmul(mat):
  v1 = tf.expand_dims(mat, 1)
  v2 = tf.reshape(v1, [3, batch_size, 1])
  return(tf.matmul(v2, v1))

# Compute SVM Model 计算对偶损失函数
first_term = tf.reduce_sum(b)
b_vec_cross = tf.matmul(tf.transpose(b), b)
y_target_cross = reshape_matmul(y_target)

second_term = tf.reduce_sum(tf.multiply(my_kernel, tf.multiply(b_vec_cross, y_target_cross)),[1,2])
loss = tf.reduce_sum(tf.negative(tf.subtract(first_term, second_term)))

# Gaussian (RBF) prediction kernel
# 现在创建预测核函数。
# 要当心reduce_sum()函数,这里我们并不想聚合三个SVM预测,
# 所以需要通过第二个参数告诉TensorFlow求和哪几个
rA = tf.reshape(tf.reduce_sum(tf.square(x_data), 1),[-1,1])
rB = tf.reshape(tf.reduce_sum(tf.square(prediction_grid), 1),[-1,1])
pred_sq_dist = tf.add(tf.subtract(rA, tf.multiply(2., tf.matmul(x_data, tf.transpose(prediction_grid)))), tf.transpose(rB))
pred_kernel = tf.exp(tf.multiply(gamma, tf.abs(pred_sq_dist)))

# 实现预测核函数后,我们创建预测函数。
# 与二类不同的是,不再对模型输出进行sign()运算。
# 因为这里实现的是一对多方法,所以预测值是分类器有最大返回值的类别。
# 使用TensorFlow的内建函数argmax()来实现该功能
prediction_output = tf.matmul(tf.multiply(y_target,b), pred_kernel)
prediction = tf.arg_max(prediction_output-tf.expand_dims(tf.reduce_mean(prediction_output,1), 1), 0)
accuracy = tf.reduce_mean(tf.cast(tf.equal(prediction, tf.argmax(y_target,0)), tf.float32))

# Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# Initialize variables
init = tf.global_variables_initializer()
sess.run(init)

# Training loop
loss_vec = []
batch_accuracy = []
for i in range(100):
  rand_index = np.random.choice(len(x_vals), size=batch_size)
  rand_x = x_vals[rand_index]
  rand_y = y_vals[:,rand_index]
  sess.run(train_step, feed_dict={x_data: rand_x, y_target: rand_y})

  temp_loss = sess.run(loss, feed_dict={x_data: rand_x, y_target: rand_y})
  loss_vec.append(temp_loss)

  acc_temp = sess.run(accuracy, feed_dict={x_data: rand_x,
                       y_target: rand_y,
                       prediction_grid:rand_x})
  batch_accuracy.append(acc_temp)

  if (i+1)%25==0:
    print('Step #' + str(i+1))
    print('Loss = ' + str(temp_loss))

# 创建数据点的预测网格,运行预测函数
x_min, x_max = x_vals[:, 0].min() - 1, x_vals[:, 0].max() + 1
y_min, y_max = x_vals[:, 1].min() - 1, x_vals[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),
           np.arange(y_min, y_max, 0.02))
grid_points = np.c_[xx.ravel(), yy.ravel()]
grid_predictions = sess.run(prediction, feed_dict={x_data: rand_x,
                          y_target: rand_y,
                          prediction_grid: grid_points})
grid_predictions = grid_predictions.reshape(xx.shape)

# Plot points and grid
plt.contourf(xx, yy, grid_predictions, cmap=plt.cm.Paired, alpha=0.8)
plt.plot(class1_x, class1_y, 'ro', label='I. setosa')
plt.plot(class2_x, class2_y, 'kx', label='I. versicolor')
plt.plot(class3_x, class3_y, 'gv', label='I. virginica')
plt.title('Gaussian SVM Results on Iris Data')
plt.xlabel('Pedal Length')
plt.ylabel('Sepal Width')
plt.legend(loc='lower right')
plt.ylim([-0.5, 3.0])
plt.xlim([3.5, 8.5])
plt.show()

# Plot batch accuracy
plt.plot(batch_accuracy, 'k-', label='Accuracy')
plt.title('Batch Accuracy')
plt.xlabel('Generation')
plt.ylabel('Accuracy')
plt.legend(loc='lower right')
plt.show()

# Plot loss over time
plt.plot(loss_vec, 'k-')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()

출력:

업데이트 지침:
대신 `argmax` 사용
Step #25
Loss = -313.391
Step #50
Loss = -650.891
Step #75
Los s = -988.39 STEP #100 #LOSS = -1325.89 Multi -Classification (세 가지 범주) I.Setosa의 비선형 가우스 SVM 모델의 결과, 여기서 감마 값은 10


입니다. SVM 알고리즘을 변경하여 세 가지 유형의 SVM 모델을 동시에 최적화하는 데 중점을 둡니다. 모델 매개변수 b는 차원 하나를 추가하여 세 모델에 대해 계산됩니다. TensorFlow의 내장 함수를 사용하면 알고리즘이 여러 유형의 유사한 알고리즘으로 쉽게 확장될 수 있음을 알 수 있습니다.

관련 권장 사항:

비선형 서포트 벡터 머신의 TensorFlow 구현 방법


위 내용은 TensorFlow를 사용하여 다중 클래스 지원 벡터 머신을 구현하기 위한 샘플 코드의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명
본 글의 내용은 네티즌들의 자발적인 기여로 작성되었으며, 저작권은 원저작자에게 있습니다. 본 사이트는 이에 상응하는 법적 책임을 지지 않습니다. 표절이나 침해가 의심되는 콘텐츠를 발견한 경우 admin@php.cn으로 문의하세요.
Python 학습 : 2 시간의 일일 연구가 충분합니까?Python 학습 : 2 시간의 일일 연구가 충분합니까?Apr 18, 2025 am 12:22 AM

하루에 2 시간 동안 파이썬을 배우는 것으로 충분합니까? 목표와 학습 방법에 따라 다릅니다. 1) 명확한 학습 계획을 개발, 2) 적절한 학습 자원 및 방법을 선택하고 3) 실습 연습 및 검토 및 통합 연습 및 검토 및 통합,이 기간 동안 Python의 기본 지식과 고급 기능을 점차적으로 마스터 할 수 있습니다.

웹 개발을위한 파이썬 : 주요 응용 프로그램웹 개발을위한 파이썬 : 주요 응용 프로그램Apr 18, 2025 am 12:20 AM

웹 개발에서 Python의 주요 응용 프로그램에는 Django 및 Flask 프레임 워크 사용, API 개발, 데이터 분석 및 시각화, 머신 러닝 및 AI 및 성능 최적화가 포함됩니다. 1. Django 및 Flask 프레임 워크 : Django는 복잡한 응용 분야의 빠른 개발에 적합하며 플라스크는 소형 또는 고도로 맞춤형 프로젝트에 적합합니다. 2. API 개발 : Flask 또는 DjangorestFramework를 사용하여 RESTFULAPI를 구축하십시오. 3. 데이터 분석 및 시각화 : Python을 사용하여 데이터를 처리하고 웹 인터페이스를 통해 표시합니다. 4. 머신 러닝 및 AI : 파이썬은 지능형 웹 애플리케이션을 구축하는 데 사용됩니다. 5. 성능 최적화 : 비동기 프로그래밍, 캐싱 및 코드를 통해 최적화

Python vs. C : 성능과 효율성 탐색Python vs. C : 성능과 효율성 탐색Apr 18, 2025 am 12:20 AM

Python은 개발 효율에서 C보다 낫지 만 C는 실행 성능이 높습니다. 1. Python의 간결한 구문 및 풍부한 라이브러리는 개발 효율성을 향상시킵니다. 2.C의 컴파일 유형 특성 및 하드웨어 제어는 실행 성능을 향상시킵니다. 선택할 때는 프로젝트 요구에 따라 개발 속도 및 실행 효율성을 평가해야합니다.

Python in Action : 실제 예제Python in Action : 실제 예제Apr 18, 2025 am 12:18 AM

Python의 실제 응용 프로그램에는 데이터 분석, 웹 개발, 인공 지능 및 자동화가 포함됩니다. 1) 데이터 분석에서 Python은 Pandas 및 Matplotlib를 사용하여 데이터를 처리하고 시각화합니다. 2) 웹 개발에서 Django 및 Flask 프레임 워크는 웹 응용 프로그램 생성을 단순화합니다. 3) 인공 지능 분야에서 Tensorflow와 Pytorch는 모델을 구축하고 훈련시키는 데 사용됩니다. 4) 자동화 측면에서 파이썬 스크립트는 파일 복사와 같은 작업에 사용할 수 있습니다.

Python의 주요 용도 : 포괄적 인 개요Python의 주요 용도 : 포괄적 인 개요Apr 18, 2025 am 12:18 AM

Python은 데이터 과학, 웹 개발 및 자동화 스크립팅 필드에 널리 사용됩니다. 1) 데이터 과학에서 Python은 Numpy 및 Pandas와 같은 라이브러리를 통해 데이터 처리 및 분석을 단순화합니다. 2) 웹 개발에서 Django 및 Flask 프레임 워크를 통해 개발자는 응용 프로그램을 신속하게 구축 할 수 있습니다. 3) 자동 스크립트에서 Python의 단순성과 표준 라이브러리가 이상적입니다.

파이썬의 주요 목적 : 유연성과 사용 편의성파이썬의 주요 목적 : 유연성과 사용 편의성Apr 17, 2025 am 12:14 AM

Python의 유연성은 다중 파리가 지원 및 동적 유형 시스템에 반영되며, 사용 편의성은 간단한 구문 및 풍부한 표준 라이브러리에서 나옵니다. 유연성 : 객체 지향, 기능 및 절차 프로그래밍을 지원하며 동적 유형 시스템은 개발 효율성을 향상시킵니다. 2. 사용 편의성 : 문법은 자연 언어에 가깝고 표준 라이브러리는 광범위한 기능을 다루며 개발 프로세스를 단순화합니다.

파이썬 : 다목적 프로그래밍의 힘파이썬 : 다목적 프로그래밍의 힘Apr 17, 2025 am 12:09 AM

Python은 초보자부터 고급 개발자에 이르기까지 모든 요구에 적합한 단순성과 힘에 호의적입니다. 다목적 성은 다음과 같이 반영됩니다. 1) 배우고 사용하기 쉽고 간단한 구문; 2) Numpy, Pandas 등과 같은 풍부한 라이브러리 및 프레임 워크; 3) 다양한 운영 체제에서 실행할 수있는 크로스 플랫폼 지원; 4) 작업 효율성을 향상시키기위한 스크립팅 및 자동화 작업에 적합합니다.

하루 2 시간 안에 파이썬 학습 : 실용 가이드하루 2 시간 안에 파이썬 학습 : 실용 가이드Apr 17, 2025 am 12:05 AM

예, 하루에 2 시간 후에 파이썬을 배우십시오. 1. 합리적인 학습 계획 개발, 2. 올바른 학습 자원을 선택하십시오. 3. 실습을 통해 학습 된 지식을 통합하십시오. 이 단계는 짧은 시간 안에 Python을 마스터하는 데 도움이 될 수 있습니다.

See all articles

핫 AI 도구

Undresser.AI Undress

Undresser.AI Undress

사실적인 누드 사진을 만들기 위한 AI 기반 앱

AI Clothes Remover

AI Clothes Remover

사진에서 옷을 제거하는 온라인 AI 도구입니다.

Undress AI Tool

Undress AI Tool

무료로 이미지를 벗다

Clothoff.io

Clothoff.io

AI 옷 제거제

AI Hentai Generator

AI Hentai Generator

AI Hentai를 무료로 생성하십시오.

뜨거운 도구

WebStorm Mac 버전

WebStorm Mac 버전

유용한 JavaScript 개발 도구

Atom Editor Mac 버전 다운로드

Atom Editor Mac 버전 다운로드

가장 인기 있는 오픈 소스 편집기

DVWA

DVWA

DVWA(Damn Vulnerable Web App)는 매우 취약한 PHP/MySQL 웹 애플리케이션입니다. 주요 목표는 보안 전문가가 법적 환경에서 자신의 기술과 도구를 테스트하고, 웹 개발자가 웹 응용 프로그램 보안 프로세스를 더 잘 이해할 수 있도록 돕고, 교사/학생이 교실 환경 웹 응용 프로그램에서 가르치고 배울 수 있도록 돕는 것입니다. 보안. DVWA의 목표는 다양한 난이도의 간단하고 간단한 인터페이스를 통해 가장 일반적인 웹 취약점 중 일부를 연습하는 것입니다. 이 소프트웨어는

SublimeText3 영어 버전

SublimeText3 영어 버전

권장 사항: Win 버전, 코드 프롬프트 지원!

SublimeText3 Mac 버전

SublimeText3 Mac 버전

신 수준의 코드 편집 소프트웨어(SublimeText3)