Rumah  >  Artikel  >  pembangunan bahagian belakang  >  看看pyhton的sklearn机器学习算法

看看pyhton的sklearn机器学习算法

coldplay.xixi
coldplay.xixike hadapan
2021-02-04 17:45:082396semak imbas

看看pyhton的sklearn机器学习算法

免费学习推荐:python视频教程

导入必要通用模块

import pandas as pdimport matplotlib.pyplot as pltimport osimport numpy as npimport copyimport reimport math

一 机器学习通用框架:以knn为例

#利用邻近点方式训练数据不太适用于高维数据from sklearn.model_selection import train_test_split#将数据分为测试集和训练集from sklearn.neighbors import KNeighborsClassifier#利用邻近点方式训练数据#1.读取数据data=pd.read_excel('数据/样本数据.xlsx')#2.将数据标准化from sklearn import preprocessingfor col in data.columns[2:]:#为了不破坏数据集中的离散变量,只将数值种类数高于10的连续变量标准化
       if len(set(data[col]))>10:
              data[col]=preprocessing.scale(data[col])#3.构造自变量和因变量并划分为训练集和测试集X=data[['month_income','education_outcome','relationship_outcome', 'entertainment_outcome','traffic_', 'express',
       'express_distance','satisfac', 'wifi_neghbor','wifi_relative', 'wifi_frend', 'internet']]y=data['wifi']X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3)#利用train_test_split进行将训练集和测试集进行分开,test_size占30%#4.模型拟合model=KNeighborsClassifier()#引入训练方法model.fit(X_train,y_train)#进行填充测试数据进行训练y_predict=model.predict(X_test)#利用测试集数据作出预测#通过修改判别概率标准修改预测结果proba=model.predict_proba(X_test)#返回基于各个测试集样本所预测的结果为0和为1的概率值#5.模型评价#(1)测试集样本数据拟合优度,model.score(X,y)model.score(X_test,y_test)#(2)构建混淆矩阵,判断预测精准程度"""
混淆矩阵中行代表真实值,列代表预测值
TN:实际为0预测为0的个数       FP:实际为0预测为1的个数
FN:实际为1预测为0的个数       TP:实际为1预测为1的个数

精准率precision=TP/(TP+FP)——被预测为1的样本的的预测正确率
召回率recall=TP/(TP+FN)——实际为1的样本的正确预测率
"""from sklearn.metrics import confusion_matrix
cfm=confusion_matrix(y_test, y_predict)plt.matshow(cfm,cmap=plt.cm.gray)#cmap参数为绘制矩阵的颜色集合,这里使用灰度plt.show()#(3)精准率和召回率from sklearn.metrics import precision_score,recall_score
precision_score(y_test, y_predict)# 精准率recall_score(y_test, y_predict)#召回率#(4)错误率矩阵row_sums = np.sum(cfm,axis=1)err_matrix = cfm/row_sums
np.fill_diagonal(err_matrix,0)#对err_matrix矩阵的对角线置0,因为这是预测正确的部分,不关心plt.matshow(err_matrix,cmap=plt.cm.gray)#亮度越高的地方代表错误率越高plt.show()

二 数据处理

#1.构造数据集from sklearn import datasets#引入数据集#n_samples为生成样本的数量,n_features为X中自变量的个数,n_targets为y中因变量的个数,bias表示使线性模型发生偏差的程度,X,y=datasets.make_regression(n_samples=100,n_features=1,n_targets=1,noise=1,bias=0.5,tail_strength=0.1)plt.figure(figsize=(12,12))plt.scatter(X,y)#2.读取数据data=pd.read_excel('数据/样本数据.xlsx')#3.将数据标准化——preprocessing.scale(data)from sklearn import preprocessing#为了不破坏数据集中的离散变量,只将数值种类数高于10的连续变量标准化for col in data.columns[2:]:
       if len(set(data[col]))>10:
              data[col]=preprocessing.scale(data[col])

三 回归

1.普通最小二乘线性回归

import numpy as npfrom sklearn.linear_model import LinearRegressionfrom sklearn.model_selection import train_test_split

X=data[['work', 'work_time', 'work_salary',
       'work_address', 'worker_number', 'month_income', 'total_area',
       'own_area', 'rend_area', 'out_area',
       'agricultal_income', 'things', 'wifi', 'internet_fee', 'cloth_outcome',
       'education_outcome', 'medcine_outcome', 'person_medicne_outcome',
       'relationship_outcome', 'food_outcome', 'entertainment_outcome',
       'agriculta_outcome', 'other_outcome', 'owe', 'owe_total', 'debt',
       'debt_way', 'distance_debt', 'distance_market', 'traffic_', 'express',
       'express_distance', 'exercise', 'satisfac', 'wifi_neghbor',
       'wifi_relative', 'wifi_frend', 'internet', 'medical_insurance']]y=data['total_income']model=LinearRegression().fit(X,y)#拟合模型model.score(X,y)#拟合优度model.coef_#查看拟合系数model.intercept_#查看拟合截距项model.predict(np.array(X.ix[25,:]).reshape(1,-1))#预测model.get_params()#得到模型的参数

2.逻辑回归Logit

from sklearn.linear_model import LogisticRegression#2.1数据处理X=data[['month_income', 'education_outcome','relationship_outcome', 'entertainment_outcome','traffic_', 'express',
       'express_distance','satisfac', 'wifi_neghbor','wifi_relative', 'wifi_frend', 'internet']]y=data['wifi']X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3)#利用train_test_split进行将训练集和测试集进行分开,test_size占30%#2.2模型拟合model = LogisticRegression()model.fit(X_train,y_train)model.score(X_test,y_test)#2.3模型预测y_predict = model.predict(X_test)#2.4通过调整判别分数标准,来调整判别结果decsion_scores = model.decision_function(X_test)#用于决定预测值取值的判别分数y_predict = decsion_scores>=5.0#将判别分数标准调整为5#2.5通过 精准率——召回率曲线图 寻找最优判别标准#由于随着判别标准的变化,精确率和召回率此消彼长,因此需要寻找一个最佳的判别标准使得精准率和召回率尽可能大from sklearn.metrics import precision_recall_curve
precisions,recalls,thresholds = precision_recall_curve(y_test,decsion_scores)#thresholds表示所有可能得判别标准,即判别分数最大与最小值之间的范围#由于precisions和recalls中比thresholds多了一个元素,因此要绘制曲线,先去掉这个元素plt.plot(thresholds,precisions[:-1])plt.plot(thresholds,recalls[:-1])plt.show()y_predict = decsion_scores>=2#根据上图显示,两线交于-0.3处,因此将判别分数标准调整为-0.3#2.6绘制ROC曲线:用于描述TPR和FPR之间的关系,ROC曲线围成的面积越大,说明模型越好"""TPR即是召回率_越大越好,FPR=(FP)/(TN+FP)_越小越好"""from sklearn.metrics import roc_curve
fprs,tprs,thresholds = roc_curve(y_test,decsion_scores)plt.plot(fprs,tprs)plt.show()#2.7绘制混淆矩阵from sklearn.metrics import confusion_matrix,precision_score,recall_score
cfm =confusion_matrix(y_test, y_predict)# 构建混淆矩阵并绘制混淆矩阵热力图plt.matshow(cfm,cmap=plt.cm.gray)#cmap参数为绘制矩阵的颜色集合,这里使用灰度plt.show()precision_score(y_test, y_predict)# 精准率recall_score(y_test, y_predict)#召回率

四 模型评价

#1.混淆矩阵,精准率和召回率from sklearn.metrics import confusion_matrix,precision_score,recall_score"""
混淆矩阵中行代表真实值,列代表预测值
TN:实际为0预测为0的个数       FP:实际为0预测为1的个数
FN:实际为1预测为0的个数       TP:实际为1预测为1的个数

精准率precision=TP/(TP+FP)——被预测为1的样本的的预测正确率
召回率recall=TP/(TP+FN)——实际为1的样本的正确预测率
"""cfm =confusion_matrix(y_test, y_predict)# 构建混淆矩阵并绘制混淆矩阵热力图plt.matshow(cfm,cmap=plt.cm.gray)#cmap参数为绘制矩阵的颜色集合,这里使用灰度plt.show()precision_score(y_test, y_predict)# 精准率recall_score(y_test, y_predict)#召回率#2.精准率和召回率作图:由于精准率和召回率此消彼长,应当选择适当的参数使二者同时尽可能的大#3.调和平均值"""精准率和召回率的调和平均值"""from sklearn.metrics import f1_score
f1_score(y_test,y_predict)#4.错误率矩阵row_sums = np.sum(cfm,axis=1)err_matrix = cfm/row_sums
np.fill_diagonal(err_matrix,0)#对err_matrix矩阵的对角线置0,因为这是预测正确的部分,不关心plt.matshow(err_matrix,cmap=plt.cm.gray)#亮度越高的地方代表错误率越高plt.show()

相关免费学习推荐:python教程(视频)

Atas ialah kandungan terperinci 看看pyhton的sklearn机器学习算法. Untuk maklumat lanjut, sila ikut artikel berkaitan lain di laman web China PHP!

Kenyataan:
Artikel ini dikembalikan pada:csdn.net. Jika ada pelanggaran, sila hubungi admin@php.cn Padam