
Use Scikit-Learn to quickly master machine learning prediction methods

王林
2023-05-27 14:26:03

In this article, we explain the differences between scikit-learn's prediction methods and when to use each one.

In scikit-learn, the predict, predict_proba, predict_log_proba, and decision_function methods all make predictions from a trained model. They differ in what they return: class labels, class probabilities, log-probabilities, and raw confidence scores, respectively.

predict method

Use the predict method to get the predicted class label for each input sample, in both binary and multiclass classification. For example, if you have trained a logistic regression model to predict whether a customer will buy a product, you can use predict to decide whether a new customer is likely to buy it.
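A minimal sketch of the customer scenario, assuming synthetic data: make_classification stands in for real customer features, and label 1 is taken to mean "bought the product". The data, the split, and the variable names are hypothetical.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Hypothetical customer data: 200 customers with 4 numeric features each.
# Label 1 stands for "bought the product", label 0 for "did not buy".
X_customers, bought = make_classification(
    n_samples=200, n_features=4, n_informative=3, n_redundant=0, random_state=42
)

# Train on the first 150 customers
model = LogisticRegression()
model.fit(X_customers[:150], bought[:150])

# predict returns one class label (0 or 1) per row for the 50 "new" customers
new_customer_labels = model.predict(X_customers[150:])
print(new_customer_labels[:10])
```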

We will use the breast cancer dataset that ships with scikit-learn. It contains measurements of tumors, each labeled as malignant or benign.
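To see what the numeric labels mean, you can inspect the bundled dataset. In load_breast_cancer, label 0 is malignant and label 1 is benign, so a prediction of [0] corresponds to a malignant tumor. This sketch loads the data as plain NumPy arrays rather than a DataFrame.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer

dataset = load_breast_cancer()

# 569 tumors, each described by 30 numeric measurements
print(dataset.data.shape)           # (569, 30)

# Index 0 is "malignant", index 1 is "benign"
print(dataset.target_names)         # ['malignant' 'benign']

# Number of tumors in each class
print(np.bincount(dataset.target))  # [212 357]
```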

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Load the dataset
dataset = load_breast_cancer(as_frame=True)

# Create the features and the target
X = dataset['data']
y = dataset['target']

# Split the dataset into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

# Build a simple pipeline that standardizes the data and trains an SVC classifier;
# probability=True is required for predict_proba and predict_log_proba later on
svc_clf = make_pipeline(StandardScaler(), SVC(max_iter=1000, probability=True))
svc_clf.fit(X_train, y_train)

# Predict the class of the first entry in X_test
print(svc_clf.predict(X_test[:1]))
# Output: [0]
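predict is not limited to a single row. The sketch below, assuming the same split and a scaling-plus-SVC pipeline like the one above (without probability=True, which predict does not need), predicts several rows at once, converts the numeric labels to names with target_names, and measures how often the predictions match the true labels. The variable names labels and accuracy are illustrative.

```python
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

dataset = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    dataset.data, dataset.target, test_size=0.25, random_state=0
)

# Standardize, then fit an SVC; probability=True is omitted because predict does not use it
svc_clf = make_pipeline(StandardScaler(), SVC(max_iter=1000))
svc_clf.fit(X_train, y_train)

# predict returns one label per row
labels = svc_clf.predict(X_test[:5])

# Translate the numeric labels (0/1) into class names
print(dataset.target_names[labels])

# Fraction of test tumors whose predicted label matches the true label
accuracy = (svc_clf.predict(X_test) == y_test).mean()
print(f"accuracy: {accuracy:.3f}")
```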

predict_proba method

The predict_proba method returns a probability estimate for each class label. In binary and multiclass classification problems, it tells you how likely each possible outcome is, not just which one is most likely. For example, if you have trained a model to classify images of animals as cats, dogs, or horses, predict_proba returns three probabilities per image, one for each label. The columns follow the order of the model's classes_ attribute, and each row sums to 1.

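print(svc_clf.predict_proba(X_test[:1]))
# Output: [[0.99848307 0.00151693]]

The first column is the probability of class 0 (malignant) and the second is the probability of class 1 (benign), which agrees with the [0] returned by predict. The exact values can vary slightly between runs, because SVC calibrates its probabilities with randomized internal cross-validation unless random_state is set.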
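The cats, dogs, and horses classifier is hypothetical, so this sketch uses the three-class iris dataset as a stand-in, with LogisticRegression as the model. It shows that predict_proba returns one column per class in the order of classes_, that each row sums to 1, and that for this model predict picks the most probable column.

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression

iris = load_iris()

# max_iter is raised so the default solver converges on the unscaled features
clf = LogisticRegression(max_iter=1000)
clf.fit(iris.data, iris.target)

proba = clf.predict_proba(iris.data[:3])

# One row per sample, one column per class, ordered as clf.classes_
print(clf.classes_)  # [0 1 2]
print(proba.shape)   # (3, 3)

# Each row is a probability distribution over the three classes
print(proba.sum(axis=1))

# The column with the highest probability matches the label from predict
print(clf.classes_[proba.argmax(axis=1)], clf.predict(iris.data[:3]))
```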

predict_log_proba method

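The predict_log_proba method is similar to predict_proba, but it returns the natural logarithm of each probability estimate instead of the raw probability. Because probabilities lie between 0 and 1, the numerical concern is very small values: multiplying many small probabilities can underflow to zero in floating point, whereas adding their logarithms stays in a representable range. Note that many estimators, including SVC, compute this simply as the logarithm of predict_proba, so the benefit comes from working in log space afterwards rather than from extra precision in the estimate itself.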
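The underflow problem can be illustrated without scikit-learn. In this sketch the 400 probabilities of 0.01 are made-up values standing in for, say, per-event probabilities that you want to combine.

```python
import numpy as np

# 400 hypothetical independent events, each with probability 0.01
probs = np.full(400, 0.01)

# Multiplying directly underflows: 0.01**400 = 1e-800 is far below the
# smallest positive float64 (about 5e-324), so the product becomes 0.0
direct = np.prod(probs)
print(direct)     # 0.0

# Summing the logarithms keeps the result finite and usable
log_total = np.sum(np.log(probs))
print(log_total)  # about -1842.07, i.e. 400 * ln(0.01)
```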

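print(svc_clf.predict_log_proba(X_test[:1]))
# Output: [[-1.51808474e-03 -6.49106473e+00]]

These values are the natural logarithms of the predict_proba output; for example, ln(0.00151693) ≈ -6.491.

decision_function method

The decision_function method is not limited to linear models: many classifiers, including the kernel SVC used here, provide it. It returns a confidence score for each input sample. For a binary problem there is one score per sample, and its sign determines the class: positive scores favor the second entry of classes_ and negative scores the first. For binary classifiers such as logistic regression and SVC, predict corresponds to a threshold of 0, but you can choose a different threshold based on the application or on domain knowledge, for example to catch more positive cases at the cost of more false alarms.

print(svc_clf.decision_function(X_test[:1]))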
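A sketch of thresholding decision_function on the same breast cancer split. It swaps in LogisticRegression for the article's SVC, because for binary logistic regression predict is exactly the decision_function score thresholded at 0, which makes the comparison clean. The stricter threshold of 2.0 is an arbitrary value chosen for illustration.

```python
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=0)

clf = make_pipeline(StandardScaler(), LogisticRegression())
clf.fit(X_train, y_train)

# One score per sample; positive scores favor class 1 (benign)
scores = clf.decision_function(X_test)

# With the default threshold of 0, the scores reproduce predict
default_labels = np.where(scores > 0, 1, 0)
print((default_labels == clf.predict(X_test)).all())  # True

# Hypothetical stricter threshold: call a tumor benign only when the score
# exceeds 2.0, so borderline cases are flagged as malignant (label 0)
strict_labels = np.where(scores > 2.0, 1, 0)
print("benign (default):", default_labels.sum(), "benign (strict):", strict_labels.sum())
```

Raising the threshold can only reduce the number of samples labeled benign, which is the kind of trade-off a domain expert might choose when missing a malignant tumor is costlier than a false alarm.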

Summary

  • Use predict for binary or multiclass classification problems when you want the predicted class label for each input sample.
  • Use predict_proba for binary or multiclass classification problems when you want a probability estimate for each possible class label.
  • Use predict_log_proba when you need log-probabilities, for example to combine many very small probabilities without numerical underflow.
  • Use decision_function when you want a raw confidence score for each input sample, for example to apply your own decision threshold.

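Note: Not every classifier implements all four methods, and some make a method available only when a specific parameter is set. For example, SVC provides predict_proba and predict_log_proba only when it is created with probability=True. According to the scikit-learn documentation, this slows down fit, because the probabilities are calibrated with internal 5-fold cross-validation, and the resulting probabilities can occasionally disagree with predict.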
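Before calling one of these methods on an unfamiliar estimator, you can check whether it is available. This sketch uses Python's built-in hasattr, which returns False for methods that scikit-learn disables for a given configuration; LinearSVC is included as an example of a classifier with no probability methods at all.

```python
from sklearn.svm import SVC, LinearSVC

# SVC exposes the probability methods only when probability=True
print(hasattr(SVC(), "predict_proba"))                      # False
print(hasattr(SVC(probability=True), "predict_proba"))      # True
print(hasattr(SVC(probability=True), "predict_log_proba"))  # True

# LinearSVC offers predict and decision_function, but no predict_proba
print(hasattr(LinearSVC(), "predict_proba"))                # False
print(hasattr(LinearSVC(), "decision_function"))            # True
```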


Statement:
This article is reproduced from 51cto.com. In case of infringement, please contact admin@php.cn to have it removed.