>기술 주변기기 >일체 포함 >Rust의 Linfa 및 Polars 라이브러리를 사용한 기계 학습: 선형 회귀

Rust의 Linfa 및 Polars 라이브러리를 사용한 기계 학습: 선형 회귀

WBOY
WBOY앞으로
2024-03-01 17:25:021134검색

Rust의 Linfa 및 Polars 라이브러리를 사용한 기계 학습: 선형 회귀

이 기사에서는 Rust의 Linfa 라이브러리와 Polars 라이브러리를 사용하여 기계 학습에서 선형 회귀 알고리즘을 구현하겠습니다.

Linfa 크레이트는 Rust를 사용하여 기계 학습 애플리케이션을 구축하기 위한 포괄적인 툴킷을 제공하는 것을 목표로 합니다.

Polars는 Apache Arrow 메모리 모델을 기반으로 하는 Rust DataFrame 라이브러리입니다. Apache Arrow는 효율적인 열 데이터 구조를 제공하며 점차 사실상의 표준이 되었습니다.

아래 예에서는 당뇨병 데이터 세트를 사용하여 선형 회귀 알고리즘을 교육합니다.

다음 명령을 사용하여 새 Rust 프로젝트를 생성합니다.

cargo new machine_learning_linfa

Cargo.toml 파일에 다음 종속성을 추가합니다.

[dependencies]linfa = "0.7.0"linfa-linear = "0.7.0"ndarray = "0.15.6"polars = { version = "0.35.4", features = ["ndarray"]}

프로젝트 루트 디렉터리에 당뇨병_file.csv 파일을 생성하고 데이터세트를 파일 .

AGESEX BMI BPS1S2S3S4S5S6Y592 32.1101 157 93.2384 4.859887151481 21.687183 103.2 703 3.89186975722 30.593156 93.6414 4.672885141241 25.384198 131.4 405 4.890389206501 23101 192 125.4 524 4.290580135231 22.689139 64.8612 4.18976897362 2290160 99.6503 3.951282138662 26.2114 255 185 564.554.24859263602 32.183179 119.4 424 4.477394110.............

여기에서 데이터 세트를 다운로드하세요: https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt

src/main.rs에 다음을 작성하세요 파일 코드:

use linfa::prelude::*;use linfa::traits::Fit;use linfa_linear::LinearRegression;use ndarray::{ArrayBase, OwnedRepr};use polars::prelude::*; // Import polarsfn main() -> Result> {// 将制表符定义为分隔符let separator = b'\t';let df = polars::prelude::CsvReader::from_path("./diabetes_file.csv")?.infer_schema(None).with_separator(separator).has_header(true).finish()?;println!("{:?}", df);// 提取并转换目标列let age_series = df.column("AGE")?.cast(&DataType::Float64)?;let target = age_series.f64()?;println!("Creating features dataset");let mut features = df.drop("AGE")?;// 遍历列并将每个列强制转换为Float64for col_name in features.get_column_names_owned() {let casted_col = df.column(&col_name)?.cast(&DataType::Float64).expect("Failed to cast column");features.with_column(casted_col)?;}println!("{:?}", df);let features_ndarray: ArrayBase<ownedrepr>, _> =features.to_ndarray::<float64type>(IndexOrder::C)?;let target_ndarray = target.to_ndarray()?.to_owned();let (dataset_training, dataset_validation) =Dataset::new(features_ndarray, target_ndarray).split_with_ratio(0.80);// 训练模型let model = LinearRegression::default().fit(&dataset_training)?;// 预测let pred = model.predict(&dataset_validation);// 评价模型let r2 = pred.r2(&dataset_validation)?;println!("r2 from prediction: {}", r2);Ok(())}</float64type></ownedrepr>

  • CSV 파일을 읽으려면 polar의 CSV 리더를 사용하세요.
  • 검사를 위해 데이터프레임을 콘솔에 인쇄하세요.
  • DataFrame에서 "AGE" 열을 선형 회귀의 대상 변수로 추출합니다. 대상 열을 기계 학습의 숫자 데이터에 대한 일반적인 형식인 Float64(이중 정밀도 부동 소수점 수)로 캐스팅합니다.
  • Convert는 linfa와의 호환성을 위해 DataFrame을 narray::ArrayBase(다차원 배열)로 변환하는 기능을 제공합니다. 기계 학습을 위해 linfa 라이브러리와 호환되는 배열로 대상 시퀀스를 변환합니다.
  • 알 수 없는 데이터에 대한 모델을 평가하기 위한 기계 학습의 일반적인 관행인 80-20 비율을 사용하여 데이터 세트를 훈련 및 검증 세트로 분할합니다.
  • linfa의 선형 회귀 알고리즘을 사용하여 훈련 데이터 세트에서 선형 회귀 모델을 훈련합니다.
  • 훈련된 모델을 사용하여 검증 데이터 세트에 대해 예측합니다.
  • 검증 데이터 세트에서 R²(결정 계수) 측정값을 계산하여 모델 성능을 평가합니다. R² 값은 회귀 예측이 실제 데이터 포인트에 얼마나 근접한지를 나타냅니다.

Execute Cargo Run의 실행 결과는 다음과 같습니다.

shape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘Creating features datasetshape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘r2 from prediction: 0.15937814745521017

빠른 반복과 빠른 프로토타이핑을 우선시하는 데이터 과학자에게 Rust의 컴파일 시간은 골칫거리가 될 수 있습니다. Rust의 강력한 정적 유형 시스템은 유형 안전성을 보장하고 런타임 오류를 줄이는 데 좋지만 코딩 프로세스에 복잡성을 추가합니다.

위 내용은 Rust의 Linfa 및 Polars 라이브러리를 사용한 기계 학습: 선형 회귀의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!

성명:
이 기사는 51cto.com에서 복제됩니다. 침해가 있는 경우 admin@php.cn으로 문의하시기 바랍니다. 삭제