Home >Technology peripherals >AI >Machine Learning with Rust's Linfa and Polars Libraries: Linear Regression
In this article, we will use Rust’s Linfa library and Polars library to implement the linear regression algorithm in machine learning.
The Linfa crate aims to provide a comprehensive toolkit for building machine learning applications using Rust.
Polars is a Rust DataFrame library based on the Apache Arrow memory model. Apache Arrow provides efficient column data structures and has gradually become the de facto standard.
In the example below, we use a diabetes data set to train a linear regression algorithm.
Create a new Rust project using the following command:
cargo new machine_learning_linfa
Add the following dependencies to the Cargo.toml file:
[dependencies]linfa = "0.7.0"linfa-linear = "0.7.0"ndarray = "0.15.6"polars = { version = "0.35.4", features = ["ndarray"]}
Create a diabetes_file.csv file in the project root directory and write the data set to the file.
AGESEX BMI BPS1S2S3S4S5S6Y592 32.1101 157 93.2384 4.859887151481 21.687183 103.2 703 3.89186975722 30.593156 93.6414 4.672885141241 25.384198 131.4 405 4.890389206501 23101 192 125.4 524 4.290580135231 22.689139 64.8612 4.18976897362 2290160 99.6503 3.951282138662 26.2114 255 185 564.554.24859263602 32.183179 119.4 424 4.477394110.............
The data set can be downloaded from here: https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt
Write the following code in the src/main.rs file:
use linfa::prelude::*;use linfa::traits::Fit;use linfa_linear::LinearRegression;use ndarray::{ArrayBase, OwnedRepr};use polars::prelude::*; // Import polarsfn main() -> Result> {// 将制表符定义为分隔符let separator = b'\t';let df = polars::prelude::CsvReader::from_path("./diabetes_file.csv")?.infer_schema(None).with_separator(separator).has_header(true).finish()?;println!("{:?}", df);// 提取并转换目标列let age_series = df.column("AGE")?.cast(&DataType::Float64)?;let target = age_series.f64()?;println!("Creating features dataset");let mut features = df.drop("AGE")?;// 遍历列并将每个列强制转换为Float64for col_name in features.get_column_names_owned() {let casted_col = df.column(&col_name)?.cast(&DataType::Float64).expect("Failed to cast column");features.with_column(casted_col)?;}println!("{:?}", df);let features_ndarray: ArrayBase<ownedrepr>, _> =features.to_ndarray::<float64type>(IndexOrder::C)?;let target_ndarray = target.to_ndarray()?.to_owned();let (dataset_training, dataset_validation) =Dataset::new(features_ndarray, target_ndarray).split_with_ratio(0.80);// 训练模型let model = LinearRegression::default().fit(&dataset_training)?;// 预测let pred = model.predict(&dataset_validation);// 评价模型let r2 = pred.r2(&dataset_validation)?;println!("r2 from prediction: {}", r2);Ok(())}</float64type></ownedrepr>
Execute cargo run, the results are as follows:
shape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘Creating features datasetshape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘r2 from prediction: 0.15937814745521017
For data scientists who prioritize rapid iteration and rapid prototyping, Rust's compilation time can be a headache. Rust's strong static type system, while good for ensuring type safety and reducing runtime errors, also adds a layer of complexity to the coding process.
The above is the detailed content of Machine Learning with Rust's Linfa and Polars Libraries: Linear Regression. For more information, please follow other related articles on the PHP Chinese website!