Heim  >  Artikel  >  Technologie-Peripheriegeräte  >  Maschinelles Lernen mit den Linfa- und Polars-Bibliotheken von Rust: Lineare Regression

Maschinelles Lernen mit den Linfa- und Polars-Bibliotheken von Rust: Lineare Regression

WBOY
WBOYnach vorne
2024-03-01 17:25:021117Durchsuche

Maschinelles Lernen mit den Linfa- und Polars-Bibliotheken von Rust: Lineare Regression

In diesem Artikel werden wir die Linfa-Bibliothek und die Polars-Bibliothek von Rust verwenden, um den linearen Regressionsalgorithmus beim maschinellen Lernen zu implementieren.

Die Linfa-Kiste zielt darauf ab, ein umfassendes Toolkit für die Erstellung von Anwendungen für maschinelles Lernen mit Rust bereitzustellen.

Polars ist eine Rust DataFrame-Bibliothek, die auf dem Apache Arrow-Speichermodell basiert. Apache Arrow bietet effiziente Spaltendatenstrukturen und hat sich nach und nach zum De-facto-Standard entwickelt.

Im folgenden Beispiel verwenden wir einen Diabetes-Datensatz, um einen linearen Regressionsalgorithmus zu trainieren.

Erstellen Sie ein neues Rust-Projekt mit dem folgenden Befehl:

cargo new machine_learning_linfa

Fügen Sie die folgenden Abhängigkeiten in der Datei Cargo.toml hinzu:

[dependencies]linfa = "0.7.0"linfa-linear = "0.7.0"ndarray = "0.15.6"polars = { version = "0.35.4", features = ["ndarray"]}

Erstellen Sie eine diabetes_file.csv-Datei im Projektstammverzeichnis und schreiben Sie den Datensatz in die Datei .

AGESEX BMI BPS1S2S3S4S5S6Y592 32.1101 157 93.2384 4.859887151481 21.687183 103.2 703 3.89186975722 30.593156 93.6414 4.672885141241 25.384198 131.4 405 4.890389206501 23101 192 125.4 524 4.290580135231 22.689139 64.8612 4.18976897362 2290160 99.6503 3.951282138662 26.2114 255 185 564.554.24859263602 32.183179 119.4 424 4.477394110.............

Laden Sie den Datensatz hier herunter: https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt

Schreiben Sie Folgendes in src/main.rs Dateicode:

use linfa::prelude::*;use linfa::traits::Fit;use linfa_linear::LinearRegression;use ndarray::{ArrayBase, OwnedRepr};use polars::prelude::*; // Import polarsfn main() -> Result> {// 将制表符定义为分隔符let separator = b'\t';let df = polars::prelude::CsvReader::from_path("./diabetes_file.csv")?.infer_schema(None).with_separator(separator).has_header(true).finish()?;println!("{:?}", df);// 提取并转换目标列let age_series = df.column("AGE")?.cast(&DataType::Float64)?;let target = age_series.f64()?;println!("Creating features dataset");let mut features = df.drop("AGE")?;// 遍历列并将每个列强制转换为Float64for col_name in features.get_column_names_owned() {let casted_col = df.column(&col_name)?.cast(&DataType::Float64).expect("Failed to cast column");features.with_column(casted_col)?;}println!("{:?}", df);let features_ndarray: ArrayBase<ownedrepr>, _> =features.to_ndarray::<float64type>(IndexOrder::C)?;let target_ndarray = target.to_ndarray()?.to_owned();let (dataset_training, dataset_validation) =Dataset::new(features_ndarray, target_ndarray).split_with_ratio(0.80);// 训练模型let model = LinearRegression::default().fit(&dataset_training)?;// 预测let pred = model.predict(&dataset_validation);// 评价模型let r2 = pred.r2(&dataset_validation)?;println!("r2 from prediction: {}", r2);Ok(())}</float64type></ownedrepr>

  • Verwenden Sie den CSV-Reader von Polar, um die CSV-Datei zu lesen.
  • Datenrahmen zur Überprüfung auf die Konsole drucken.
  • Extrahieren Sie die Spalte „AGE“ aus dem DataFrame als Zielvariable für die lineare Regression. Wandeln Sie die Zielspalte in Float64 (Gleitkommazahl mit doppelter Genauigkeit) um, ein gängiges Format für numerische Daten beim maschinellen Lernen.
  • Konvertieren Sie Features DataFrame in narray::ArrayBase (ein mehrdimensionales Array), um die Kompatibilität mit Linfa zu gewährleisten. Konvertieren Sie Zielsequenzen in Arrays, die mit der Linfa-Bibliothek für maschinelles Lernen kompatibel sind.
  • Teilen Sie den Datensatz im Verhältnis 80:20 in Trainings- und Validierungssätze auf, was beim maschinellen Lernen eine gängige Praxis zur Bewertung von Modellen anhand unbekannter Daten ist.
  • Verwenden Sie den linearen Regressionsalgorithmus von Linfa, um ein lineares Regressionsmodell anhand des Trainingsdatensatzes zu trainieren.
  • Verwenden Sie das trainierte Modell, um Vorhersagen zum Validierungsdatensatz zu treffen.
  • Berechnen Sie das R²-Maß (Bestimmungskoeffizient) für den Validierungsdatensatz, um die Leistung des Modells zu bewerten. Der R²-Wert gibt an, wie gut sich die Regressionsvorhersagen den tatsächlichen Datenpunkten annähern.

Führen Sie den Cargo-Lauf aus, und die laufenden Ergebnisse lauten wie folgt:

shape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘Creating features datasetshape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘r2 from prediction: 0.15937814745521017

Für Datenwissenschaftler, die schnelle Iteration und schnelles Prototyping priorisieren, kann die Kompilierungszeit von Rust Kopfschmerzen bereiten. Rusts starkes statisches Typsystem ist zwar gut für die Gewährleistung der Typsicherheit und die Reduzierung von Laufzeitfehlern geeignet, erhöht aber auch die Komplexität des Codierungsprozesses.

Das obige ist der detaillierte Inhalt vonMaschinelles Lernen mit den Linfa- und Polars-Bibliotheken von Rust: Lineare Regression. Für weitere Informationen folgen Sie bitte anderen verwandten Artikeln auf der PHP chinesischen Website!

Stellungnahme:
Dieser Artikel ist reproduziert unter:51cto.com. Bei Verstößen wenden Sie sich bitte an admin@php.cn löschen