搜索
首页后端开发Python教程为什么现实世界的机器学习需要分布式计算

Why You Need Distributed Computing for Real-World Machine Learning

PySpark 如何帮助您像专业人士一样处理庞大的数据集

PyTorch 和 TensorFlow 等机器学习框架非常适合构建模型。但现实是,当涉及到现实世界的项目时(处理巨大的数据集),您需要的不仅仅是一个好的模型。您需要一种有效处理和管理所有数据的方法。这就是像 PySpark 这样的分布式计算可以拯救世界的地方。

让我们来分析一下为什么在现实世界的机器学习中处理大数据意味着超越 PyTorch 和 TensorFlow,以及 PySpark 如何帮助您实现这一目标。
真正的问题:大数据
您在网上看到的大多数机器学习示例都使用小型、易于管理的数据集。您可以将整个事情放入内存中,进行尝试,并在几分钟内训练模型。但在现实场景中(例如信用卡欺诈检测、推荐系统或财务预测),您要处理数百万甚至数十亿行。突然间,您的笔记本电脑或服务器无法处理它。

如果您尝试将所有数据一次性加载到 PyTorch 或 TensorFlow 中,事情就会崩溃。这些框架是为模型训练而设计的,而不是为了有效处理巨大的数据集。这就是分布式计算变得至关重要的地方。
为什么 PyTorch 和 TensorFlow 还不够
PyTorch 和 TensorFlow 非常适合构建和优化模型,但在处理大规模数据任务时却表现不佳。两个主要问题:

  • 内存过载:他们在训练之前将整个数据集加载到内存中。这适用于小型数据集,但当您拥有 TB 级的数据时,游戏就结束了。
  • 无分布式数据处理:PyTorch 和 TensorFlow 并不是为处理分布式数据处理而构建的。如果您有大量数据分布在多台机器上,那么它们并没有真正的帮助。

这就是 PySpark 的闪光点。它旨在处理分布式数据,在多台机器上高效处理数据,同时处理大量数据集,而不会导致系统崩溃。

真实示例:使用 PySpark 检测信用卡欺诈
让我们深入研究一个例子。假设您正在开发使用信用卡交易数据的欺诈检测系统。在本例中,我们将使用 Kaggle 的流行数据集。它包含超过 284,000 笔交易,其中不到 1% 是欺诈交易。

第 1 步:在 Google Colab 中设置 PySpark
为此,我们将使用 Google Colab,因为它允许我们以最少的设置运行 PySpark。

!pip install pyspark

接下来,导入必要的库并启动 Spark 会话。

import os
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, sum, udf
from pyspark.ml.feature import VectorAssembler, StringIndexer, MinMaxScaler
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors
import numpy as np
from pyspark.sql.types import FloatType

启动 pyspark 会话

spark = SparkSession.builder \
    .appName("FraudDetectionImproved") \
    .master("local[*]") \
    .config("spark.executorEnv.PYTHONHASHSEED", "0") \
    .getOrCreate()

第 2 步:加载和准备数据

data = spark.read.csv('creditcard.csv', header=True, inferSchema=True)
data = data.orderBy("Time")  # Ensure data is sorted by time
data.show(5)
data.describe().show()
# Check for missing values in each column
data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns]).show()

# Prepare the feature columns
feature_columns = data.columns
feature_columns.remove("Class")  # Removing "Class" column as it is our label

# Assemble features into a single vector
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")
data = assembler.transform(data)
data.select("features", "Class").show(5)

# Split data into train (60%), test (20%), and unseen (20%)
train_data, temp_data = data.randomSplit([0.6, 0.4], seed=42)
test_data, unseen_data = temp_data.randomSplit([0.5, 0.5], seed=42)

# Print class distribution in each dataset
print("Train Data:")
train_data.groupBy("Class").count().show()

print("Test and parameter optimisation Data:")
test_data.groupBy("Class").count().show()

print("Unseen Data:")
unseen_data.groupBy("Class").count().show()

第 3 步:初始化模型

# Initialize RandomForestClassifier
rf = RandomForestClassifier(labelCol="Class", featuresCol="features", probabilityCol="probability")

# Create ParamGrid for Cross Validation
paramGrid = ParamGridBuilder() \
    .addGrid(rf.numTrees, [10, 20 ]) \
    .addGrid(rf.maxDepth, [5, 10]) \
    .build()

# Create 5-fold CrossValidator
crossval = CrossValidator(estimator=rf,
                          estimatorParamMaps=paramGrid,
                          evaluator=BinaryClassificationEvaluator(labelCol="Class", metricName="areaUnderROC"),
                          numFolds=5)

第 4 步:拟合、运行交叉验证,并选择最佳参数集

# Run cross-validation, and choose the best set of parameters
rf_model = crossval.fit(train_data)

# Make predictions on test data
predictions_rf = rf_model.transform(test_data)

# Evaluate Random Forest Model
binary_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
pr_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderPR")

auc_rf = binary_evaluator.evaluate(predictions_rf)
auprc_rf = pr_evaluator.evaluate(predictions_rf)
print(f"Random Forest - AUC: {auc_rf:.4f}, AUPRC: {auprc_rf:.4f}")

# UDF to extract positive probability from probability vector
extract_prob = udf(lambda prob: float(prob[1]), FloatType())
predictions_rf = predictions_rf.withColumn("positive_probability", extract_prob(col("probability")))

第 5 步计算精确率、召回率和 F1 分数的函数

# Function to calculate precision, recall, and F1-score
def calculate_metrics(predictions):
    tp = predictions.filter((col("Class") == 1) & (col("prediction") == 1)).count()
    fp = predictions.filter((col("Class") == 0) & (col("prediction") == 1)).count()
    fn = predictions.filter((col("Class") == 1) & (col("prediction") == 0)).count()

    precision = tp / (tp + fp) if (tp + fp) != 0 else 0
    recall = tp / (tp + fn) if (tp + fn) != 0 else 0
    f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0

    return precision, recall, f1_score

第 6 步:找到模型的最佳阈值

# Find the best threshold for the model
best_threshold = 0.5
best_f1 = 0
for threshold in np.arange(0.1, 0.9, 0.1):
    thresholded_predictions = predictions_rf.withColumn("prediction", (col("positive_probability") > threshold).cast("double"))
    precision, recall, f1 = calculate_metrics(thresholded_predictions)

    if f1 > best_f1:
        best_f1 = f1
        best_threshold = threshold

print(f"Best threshold: {best_threshold}, Best F1-score: {best_f1:.4f}")

第七步:评估未见过的数据

# Evaluate on unseen data
predictions_unseen = rf_model.transform(unseen_data)
auc_unseen = binary_evaluator.evaluate(predictions_unseen)
print(f"Unseen Data - AUC: {auc_unseen:.4f}")

precision, recall, f1 = calculate_metrics(predictions_unseen)
print(f"Unseen Data - Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")

area_under_roc = binary_evaluator.evaluate(predictions_unseen)
area_under_pr = pr_evaluator.evaluate(predictions_unseen)
print(f"Unseen Data - AUC: {area_under_roc:.4f}, AUPRC: {area_under_pr:.4f}")

结果

Best threshold: 0.30000000000000004, Best F1-score: 0.9062
Unseen Data - AUC: 0.9384
Unseen Data - Precision: 0.9655, Recall: 0.7568, F1-score: 0.8485
Unseen Data - AUC: 0.9423, AUPRC: 0.8618

然后您可以保存此模型(几KB)并在 pyspark 管道中的任何地方使用它

rf_model.save()

这就是 PySpark 在处理现实机器学习任务中的大型数据集时产生巨大差异的原因:

轻松扩展:PySpark 可以跨集群分配任务,让您能够处理 TB 级的数据,而不会耗尽内存。
动态数据处理:PySpark 不需要将整个数据集加载到内存中。它根据需要处理数据,这使得它更加高效。
更快的模型训练:借助分布式计算,您可以通过在多台机器上分配计算工作负载来更快地训练模型。
最后的想法
PyTorch 和 TensorFlow 是构建机器学习模型的绝佳工具,但对于现实世界的大规模任务,您需要更多工具。使用 PySpark 进行分布式计算可让您高效处理庞大数据集、实时处理数据并扩展机器学习管道。

因此,下次您处理海量数据时(无论是欺诈检测、推荐系统还是财务分析),请考虑使用 PySpark 将您的项目提升到新的水平。

有关完整的代码和结果,请查看此笔记本。 :
https://colab.research.google.com/drive/1W9naxNZirirLRodSEnHAUWevYd5LH8D4?authuser=5#scrollTo=odmodmqKcY23

__

我是 Swapnil,请随意留下您的评论、结果和想法,或者联系我 - swapnil@nooffice.no 获取数据、软件开发工作和工作

以上是为什么现实世界的机器学习需要分布式计算的详细内容。更多信息请关注PHP中文网其他相关文章!

声明
本文内容由网友自发贡献,版权归原作者所有,本站不承担相应法律责任。如您发现有涉嫌抄袭侵权的内容,请联系admin@php.cn
如何解决Linux终端中查看Python版本时遇到的权限问题?如何解决Linux终端中查看Python版本时遇到的权限问题?Apr 01, 2025 pm 05:09 PM

Linux终端中查看Python版本时遇到权限问题的解决方法当你在Linux终端中尝试查看Python的版本时,输入python...

我如何使用美丽的汤来解析HTML?我如何使用美丽的汤来解析HTML?Mar 10, 2025 pm 06:54 PM

本文解释了如何使用美丽的汤库来解析html。 它详细介绍了常见方法,例如find(),find_all(),select()和get_text(),以用于数据提取,处理不同的HTML结构和错误以及替代方案(SEL)

如何使用TensorFlow或Pytorch进行深度学习?如何使用TensorFlow或Pytorch进行深度学习?Mar 10, 2025 pm 06:52 PM

本文比较了Tensorflow和Pytorch的深度学习。 它详细介绍了所涉及的步骤:数据准备,模型构建,培训,评估和部署。 框架之间的关键差异,特别是关于计算刻度的

如何使用Python创建命令行接口(CLI)?如何使用Python创建命令行接口(CLI)?Mar 10, 2025 pm 06:48 PM

本文指导Python开发人员构建命令行界面(CLIS)。 它使用Typer,Click和ArgParse等库详细介绍,强调输入/输出处理,并促进用户友好的设计模式,以提高CLI可用性。

在Python中如何高效地将一个DataFrame的整列复制到另一个结构不同的DataFrame中?在Python中如何高效地将一个DataFrame的整列复制到另一个结构不同的DataFrame中?Apr 01, 2025 pm 11:15 PM

在使用Python的pandas库时,如何在两个结构不同的DataFrame之间进行整列复制是一个常见的问题。假设我们有两个Dat...

哪些流行的Python库及其用途?哪些流行的Python库及其用途?Mar 21, 2025 pm 06:46 PM

本文讨论了诸如Numpy,Pandas,Matplotlib,Scikit-Learn,Tensorflow,Tensorflow,Django,Blask和请求等流行的Python库,并详细介绍了它们在科学计算,数据分析,可视化,机器学习,网络开发和H中的用途

解释Python中虚拟环境的目的。解释Python中虚拟环境的目的。Mar 19, 2025 pm 02:27 PM

文章讨论了虚拟环境在Python中的作用,重点是管理项目依赖性并避免冲突。它详细介绍了他们在改善项目管理和减少依赖问题方面的创建,激活和利益。

什么是正则表达式?什么是正则表达式?Mar 20, 2025 pm 06:25 PM

正则表达式是在编程中进行模式匹配和文本操作的强大工具,从而提高了各种应用程序的文本处理效率。

See all articles

热AI工具

Undresser.AI Undress

Undresser.AI Undress

人工智能驱动的应用程序,用于创建逼真的裸体照片

AI Clothes Remover

AI Clothes Remover

用于从照片中去除衣服的在线人工智能工具。

Undress AI Tool

Undress AI Tool

免费脱衣服图片

Clothoff.io

Clothoff.io

AI脱衣机

AI Hentai Generator

AI Hentai Generator

免费生成ai无尽的。

热门文章

R.E.P.O.能量晶体解释及其做什么(黄色晶体)
3 周前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.最佳图形设置
3 周前By尊渡假赌尊渡假赌尊渡假赌
R.E.P.O.如果您听不到任何人,如何修复音频
3 周前By尊渡假赌尊渡假赌尊渡假赌

热工具

ZendStudio 13.5.1 Mac

ZendStudio 13.5.1 Mac

功能强大的PHP集成开发环境

EditPlus 中文破解版

EditPlus 中文破解版

体积小,语法高亮,不支持代码提示功能

螳螂BT

螳螂BT

Mantis是一个易于部署的基于Web的缺陷跟踪工具,用于帮助产品缺陷跟踪。它需要PHP、MySQL和一个Web服务器。请查看我们的演示和托管服务。

SublimeText3 Linux新版

SublimeText3 Linux新版

SublimeText3 Linux最新版

mPDF

mPDF

mPDF是一个PHP库,可以从UTF-8编码的HTML生成PDF文件。原作者Ian Back编写mPDF以从他的网站上“即时”输出PDF文件,并处理不同的语言。与原始脚本如HTML2FPDF相比,它的速度较慢,并且在使用Unicode字体时生成的文件较大,但支持CSS样式等,并进行了大量增强。支持几乎所有语言,包括RTL(阿拉伯语和希伯来语)和CJK(中日韩)。支持嵌套的块级元素(如P、DIV),