


And How PySpark Can Help You Handle Huge Datasets Like a Pro
Machine learning frameworks like PyTorch and TensorFlow are awesome for building models. But the reality is, when it comes to real-world projects—where you’re dealing with gigantic datasets—you need more than just a good model. You need a way to efficiently process and manage all that data. That’s where distributed computing, like PySpark, comes in to save the day.
Let’s break down why handling big data in real-world machine learning means going beyond PyTorch and TensorFlow, and how PySpark helps you get there.
The Real Problem: Big Data
Most ML examples you see online use small, manageable datasets. You can fit the whole thing into memory, play around with it, and train a model in minutes. But in real-world scenarios—like credit card fraud detection, recommendation systems, or financial forecasts—you’re dealing with millions or even billions of rows. Suddenly, your laptop or server can’t handle it.
If you try loading all that data into PyTorch or TensorFlow at once, things will break. These frameworks are designed for model training, not for efficiently handling huge datasets. This is where distributed computing becomes crucial.
Why PyTorch and TensorFlow Aren’t Enough
PyTorch and TensorFlow are great for building and optimizing models, but they fall short when dealing with large-scale data tasks. Two major problems:
- Memory Overload: They load the entire dataset into memory before training. That works for small datasets, but when you’ve got terabytes of data, it’s game over.
- No Distributed Data Processing: PyTorch and TensorFlow aren’t built to handle distributed data processing. If you’ve got massive amounts of data spread across multiple machines, they don’t really help.
This is where PySpark shines. It’s designed to work with distributed data, processing it efficiently across multiple machines while handling massive datasets without crashing your system.
Real-World Example: Credit Card Fraud Detection with PySpark
Let’s dive into an example. Suppose you’re working on a fraud detection system using credit card transaction data. In this case, we’ll use a popular dataset from Kaggle. It contains over 284,000 transactions, and less than 1% of them are fraudulent.
Step 1: Set Up PySpark in Google Colab
We’ll use Google Colab for this because it lets us run PySpark with minimal setup.
!pip install pyspark
Next, import the necessary libraries and start a Spark session.
import os from pyspark.sql import SparkSession from pyspark.sql.functions import col, sum, udf from pyspark.ml.feature import VectorAssembler, StringIndexer, MinMaxScaler from pyspark.ml.classification import RandomForestClassifier, GBTClassifier from pyspark.ml.tuning import ParamGridBuilder, CrossValidator from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator from pyspark.ml.linalg import Vectors import numpy as np from pyspark.sql.types import FloatType
start a pyspark session
spark = SparkSession.builder \ .appName("FraudDetectionImproved") \ .master("local[*]") \ .config("spark.executorEnv.PYTHONHASHSEED", "0") \ .getOrCreate()
Step 2: Load and Prepare data
data = spark.read.csv('creditcard.csv', header=True, inferSchema=True) data = data.orderBy("Time") # Ensure data is sorted by time data.show(5) data.describe().show()
# Check for missing values in each column data.select([sum(col(c).isNull().cast("int")).alias(c) for c in data.columns]).show() # Prepare the feature columns feature_columns = data.columns feature_columns.remove("Class") # Removing "Class" column as it is our label # Assemble features into a single vector assembler = VectorAssembler(inputCols=feature_columns, outputCol="features") data = assembler.transform(data) data.select("features", "Class").show(5) # Split data into train (60%), test (20%), and unseen (20%) train_data, temp_data = data.randomSplit([0.6, 0.4], seed=42) test_data, unseen_data = temp_data.randomSplit([0.5, 0.5], seed=42) # Print class distribution in each dataset print("Train Data:") train_data.groupBy("Class").count().show() print("Test and parameter optimisation Data:") test_data.groupBy("Class").count().show() print("Unseen Data:") unseen_data.groupBy("Class").count().show()
Step 3: Initialise Model
# Initialize RandomForestClassifier rf = RandomForestClassifier(labelCol="Class", featuresCol="features", probabilityCol="probability") # Create ParamGrid for Cross Validation paramGrid = ParamGridBuilder() \ .addGrid(rf.numTrees, [10, 20 ]) \ .addGrid(rf.maxDepth, [5, 10]) \ .build() # Create 5-fold CrossValidator crossval = CrossValidator(estimator=rf, estimatorParamMaps=paramGrid, evaluator=BinaryClassificationEvaluator(labelCol="Class", metricName="areaUnderROC"), numFolds=5)
Step 4: Fit, Run cross-validation, and choose the best set of parameters
# Run cross-validation, and choose the best set of parameters rf_model = crossval.fit(train_data) # Make predictions on test data predictions_rf = rf_model.transform(test_data) # Evaluate Random Forest Model binary_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderROC") pr_evaluator = BinaryClassificationEvaluator(labelCol="Class", rawPredictionCol="rawPrediction", metricName="areaUnderPR") auc_rf = binary_evaluator.evaluate(predictions_rf) auprc_rf = pr_evaluator.evaluate(predictions_rf) print(f"Random Forest - AUC: {auc_rf:.4f}, AUPRC: {auprc_rf:.4f}") # UDF to extract positive probability from probability vector extract_prob = udf(lambda prob: float(prob[1]), FloatType()) predictions_rf = predictions_rf.withColumn("positive_probability", extract_prob(col("probability")))
Step 5 Function to calculate precision, recall, and F1-score
# Function to calculate precision, recall, and F1-score def calculate_metrics(predictions): tp = predictions.filter((col("Class") == 1) & (col("prediction") == 1)).count() fp = predictions.filter((col("Class") == 0) & (col("prediction") == 1)).count() fn = predictions.filter((col("Class") == 1) & (col("prediction") == 0)).count() precision = tp / (tp + fp) if (tp + fp) != 0 else 0 recall = tp / (tp + fn) if (tp + fn) != 0 else 0 f1_score = (2 * precision * recall) / (precision + recall) if (precision + recall) != 0 else 0 return precision, recall, f1_score
Step 6: Find the best threshold for the model
# Find the best threshold for the model best_threshold = 0.5 best_f1 = 0 for threshold in np.arange(0.1, 0.9, 0.1): thresholded_predictions = predictions_rf.withColumn("prediction", (col("positive_probability") > threshold).cast("double")) precision, recall, f1 = calculate_metrics(thresholded_predictions) if f1 > best_f1: best_f1 = f1 best_threshold = threshold print(f"Best threshold: {best_threshold}, Best F1-score: {best_f1:.4f}")
Step7: Evaluate on unseen Data
# Evaluate on unseen data predictions_unseen = rf_model.transform(unseen_data) auc_unseen = binary_evaluator.evaluate(predictions_unseen) print(f"Unseen Data - AUC: {auc_unseen:.4f}") precision, recall, f1 = calculate_metrics(predictions_unseen) print(f"Unseen Data - Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}") area_under_roc = binary_evaluator.evaluate(predictions_unseen) area_under_pr = pr_evaluator.evaluate(predictions_unseen) print(f"Unseen Data - AUC: {area_under_roc:.4f}, AUPRC: {area_under_pr:.4f}")
RESULTS
Best threshold: 0.30000000000000004, Best F1-score: 0.9062 Unseen Data - AUC: 0.9384 Unseen Data - Precision: 0.9655, Recall: 0.7568, F1-score: 0.8485 Unseen Data - AUC: 0.9423, AUPRC: 0.8618
You can then save this model (few KBs) and use it anywehere in pyspark pipeline
rf_model.save()
Here’s why PySpark makes a huge difference when dealing with large datasets in real-world machine learning tasks:
It Scales Easily: PySpark can distribute tasks across clusters, allowing you to process terabytes of data without running out of memory.
On-the-fly Data Processing: PySpark doesn’t need to load the entire dataset into memory. It processes the data as needed, which makes it much more efficient.
Faster Model Training: With distributed computing, you can train models faster by distributing the computational workload across multiple machines.
Final Thoughts
PyTorch and TensorFlow are fantastic tools for building machine learning models, but for real-world, large-scale tasks, you need more. Distributed computing with PySpark allows you to handle huge datasets efficiently, process data in real-time, and scale your machine learning pipelines.
So, the next time you’re working with massive data—whether it’s fraud detection, recommendation systems, or financial analysis—consider using PySpark to take your project to the next level.
For the full code and results, check out this notebook. :
https://colab.research.google.com/drive/1W9naxNZirirLRodSEnHAUWevYd5LH8D4?authuser=5#scrollTo=odmodmqKcY23
__
我是 Swapnil,请随意留下您的评论、结果和想法,或者联系我 - swapnil@nooffice.no 获取数据、软件开发工作和工作
The above is the detailed content of Why You Need Distributed Computing for Real-World Machine Learning. For more information, please follow other related articles on the PHP Chinese website!

This tutorial demonstrates how to use Python to process the statistical concept of Zipf's law and demonstrates the efficiency of Python's reading and sorting large text files when processing the law. You may be wondering what the term Zipf distribution means. To understand this term, we first need to define Zipf's law. Don't worry, I'll try to simplify the instructions. Zipf's Law Zipf's law simply means: in a large natural language corpus, the most frequently occurring words appear about twice as frequently as the second frequent words, three times as the third frequent words, four times as the fourth frequent words, and so on. Let's look at an example. If you look at the Brown corpus in American English, you will notice that the most frequent word is "th

This article explains how to use Beautiful Soup, a Python library, to parse HTML. It details common methods like find(), find_all(), select(), and get_text() for data extraction, handling of diverse HTML structures and errors, and alternatives (Sel

This article compares TensorFlow and PyTorch for deep learning. It details the steps involved: data preparation, model building, training, evaluation, and deployment. Key differences between the frameworks, particularly regarding computational grap

Serialization and deserialization of Python objects are key aspects of any non-trivial program. If you save something to a Python file, you do object serialization and deserialization if you read the configuration file, or if you respond to an HTTP request. In a sense, serialization and deserialization are the most boring things in the world. Who cares about all these formats and protocols? You want to persist or stream some Python objects and retrieve them in full at a later time. This is a great way to see the world on a conceptual level. However, on a practical level, the serialization scheme, format or protocol you choose may determine the speed, security, freedom of maintenance status, and other aspects of the program

Python's statistics module provides powerful data statistical analysis capabilities to help us quickly understand the overall characteristics of data, such as biostatistics and business analysis. Instead of looking at data points one by one, just look at statistics such as mean or variance to discover trends and features in the original data that may be ignored, and compare large datasets more easily and effectively. This tutorial will explain how to calculate the mean and measure the degree of dispersion of the dataset. Unless otherwise stated, all functions in this module support the calculation of the mean() function instead of simply summing the average. Floating point numbers can also be used. import random import statistics from fracti

In this tutorial you'll learn how to handle error conditions in Python from a whole system point of view. Error handling is a critical aspect of design, and it crosses from the lowest levels (sometimes the hardware) all the way to the end users. If y

The article discusses popular Python libraries like NumPy, Pandas, Matplotlib, Scikit-learn, TensorFlow, Django, Flask, and Requests, detailing their uses in scientific computing, data analysis, visualization, machine learning, web development, and H

This tutorial builds upon the previous introduction to Beautiful Soup, focusing on DOM manipulation beyond simple tree navigation. We'll explore efficient search methods and techniques for modifying HTML structure. One common DOM search method is ex


Hot AI Tools

Undresser.AI Undress
AI-powered app for creating realistic nude photos

AI Clothes Remover
Online AI tool for removing clothes from photos.

Undress AI Tool
Undress images for free

Clothoff.io
AI clothes remover

AI Hentai Generator
Generate AI Hentai for free.

Hot Article

Hot Tools

Dreamweaver Mac version
Visual web development tools

VSCode Windows 64-bit Download
A free and powerful IDE editor launched by Microsoft

MinGW - Minimalist GNU for Windows
This project is in the process of being migrated to osdn.net/projects/mingw, you can continue to follow us there. MinGW: A native Windows port of the GNU Compiler Collection (GCC), freely distributable import libraries and header files for building native Windows applications; includes extensions to the MSVC runtime to support C99 functionality. All MinGW software can run on 64-bit Windows platforms.

PhpStorm Mac version
The latest (2018.2.1) professional PHP integrated development tool

SAP NetWeaver Server Adapter for Eclipse
Integrate Eclipse with SAP NetWeaver application server.
