ホームページ >テクノロジー周辺機器 >AI >Rust の Linfa ライブラリと Polars ライブラリを使用した機械学習: 線形回帰

Rust の Linfa ライブラリと Polars ライブラリを使用した機械学習: 線形回帰

WBOY
WBOY転載
2024-03-01 17:25:021194ブラウズ

Rust の Linfa ライブラリと Polars ライブラリを使用した機械学習: 線形回帰

この記事では、Rust の Linfa ライブラリと Polars ライブラリを使用して、機械学習に線形回帰アルゴリズムを実装します。

Linfa クレートは、Rust を使用して機械学習アプリケーションを構築するための包括的なツールキットを提供することを目的としています。

Polars は、Apache Arrow メモリ モデルに基づく Rust DataFrame ライブラリです。 Apache Arrow は効率的な列データ構造を提供し、徐々に事実上の標準になりました。

以下の例では、糖尿病データセットを使用して線形回帰アルゴリズムをトレーニングします。

次のコマンドを使用して新しい Rust プロジェクトを作成します:

cargo new machine_learning_linfa

次の依存関係を Cargo.toml ファイルに追加します:

[dependencies]linfa = "0.7.0"linfa-linear = "0.7.0"ndarray = "0.15.6"polars = { version = "0.35.4", features = ["ndarray"]}

プロジェクトのルート ディレクトリに糖尿病_file.csv ファイルを作成し、データ セットをファイルに書き込みます。

#
AGESEX BMI BPS1S2S3S4S5S6Y592 32.1101 157 93.2384 4.859887151481 21.687183 103.2 703 3.89186975722 30.593156 93.6414 4.672885141241 25.384198 131.4 405 4.890389206501 23101 192 125.4 524 4.290580135231 22.689139 64.8612 4.18976897362 2290160 99.6503 3.951282138662 26.2114 255 185 564.554.24859263602 32.183179 119.4 424 4.477394110.............
データセットはここからダウンロードできます: https://www4.stat.ncsu.edu/~boos/var.select/diabetes.tab.txt

src/main.rs ファイルに次のコードを記述します。

use linfa::prelude::*;use linfa::traits::Fit;use linfa_linear::LinearRegression;use ndarray::{ArrayBase, OwnedRepr};use polars::prelude::*; // Import polarsfn main() -> Result> {// 将制表符定义为分隔符let separator = b'\t';let df = polars::prelude::CsvReader::from_path("./diabetes_file.csv")?.infer_schema(None).with_separator(separator).has_header(true).finish()?;println!("{:?}", df);// 提取并转换目标列let age_series = df.column("AGE")?.cast(&DataType::Float64)?;let target = age_series.f64()?;println!("Creating features dataset");let mut features = df.drop("AGE")?;// 遍历列并将每个列强制转换为Float64for col_name in features.get_column_names_owned() {let casted_col = df.column(&col_name)?.cast(&DataType::Float64).expect("Failed to cast column");features.with_column(casted_col)?;}println!("{:?}", df);let features_ndarray: ArrayBase<ownedrepr>, _> =features.to_ndarray::<float64type>(IndexOrder::C)?;let target_ndarray = target.to_ndarray()?.to_owned();let (dataset_training, dataset_validation) =Dataset::new(features_ndarray, target_ndarray).split_with_ratio(0.80);// 训练模型let model = LinearRegression::default().fit(&dataset_training)?;// 预测let pred = model.predict(&dataset_validation);// 评价模型let r2 = pred.r2(&dataset_validation)?;println!("r2 from prediction: {}", r2);Ok(())}</float64type></ownedrepr>

    polar の CSV リーダーを使用して、 CSVファイル。
  • 検査のためにデータフレームをコンソールに出力します。
  • 線形回帰のターゲット変数としてデータフレームから「AGE」列を抽出します。ターゲット列を Float64 (倍精度浮動小数点数) にキャストします。これは、機械学習における数値データの一般的な形式です。
  • linfa との互換性を確保するために、機能 DataFrame を narray::ArrayBase (多次元配列) に変換します。ターゲット シーケンスを、機械学習用の linfa ライブラリと互換性のある配列に変換します。
  • 80 対 20 の比率を使用してデータ セットをトレーニング セットと検証セットに分割します。これは、未知のデータのモデルを評価するための機械学習における一般的な方法です。
  • linfa の線形回帰アルゴリズムを使用して、トレーニング データ セットで線形回帰モデルをトレーニングします。
  • トレーニングされたモデルを使用して検証データセットを予測します。
  • 検証データセットの R² (決定係数) メジャーを計算して、モデルのパフォーマンスを評価します。 R² 値は、回帰予測が実際のデータ ポイントにどの程度近似しているかを示します。
カーゴ実行を実行すると、結果は次のようになります:

shape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘Creating features datasetshape: (442, 11)┌─────┬─────┬──────┬───────┬───┬──────┬────────┬─────┬─────┐│ AGE ┆ SEX ┆ BMI┆ BP┆ … ┆ S4 ┆ S5 ┆ S6┆ Y ││ --- ┆ --- ┆ ---┆ --- ┆ ┆ ---┆ ---┆ --- ┆ --- ││ i64 ┆ i64 ┆ f64┆ f64 ┆ ┆ f64┆ f64┆ i64 ┆ i64 │╞═════╪═════╪══════╪═══════╪═══╪══════╪════════╪═════╪═════╡│ 59┆ 2 ┆ 32.1 ┆ 101.0 ┆ … ┆ 4.0┆ 4.8598 ┆ 87┆ 151 ││ 48┆ 1 ┆ 21.6 ┆ 87.0┆ … ┆ 3.0┆ 3.8918 ┆ 69┆ 75││ 72┆ 2 ┆ 30.5 ┆ 93.0┆ … ┆ 4.0┆ 4.6728 ┆ 85┆ 141 ││ 24┆ 1 ┆ 25.3 ┆ 84.0┆ … ┆ 5.0┆ 4.8903 ┆ 89┆ 206 ││ … ┆ … ┆ …┆ … ┆ … ┆ …┆ …┆ … ┆ … ││ 47┆ 2 ┆ 24.9 ┆ 75.0┆ … ┆ 5.0┆ 4.4427 ┆ 102 ┆ 104 ││ 60┆ 2 ┆ 24.9 ┆ 99.67 ┆ … ┆ 3.77 ┆ 4.1271 ┆ 95┆ 132 ││ 36┆ 1 ┆ 30.0 ┆ 95.0┆ … ┆ 4.79 ┆ 5.1299 ┆ 85┆ 220 ││ 36┆ 1 ┆ 19.6 ┆ 71.0┆ … ┆ 3.0┆ 4.5951 ┆ 92┆ 57│└─────┴─────┴──────┴───────┴───┴──────┴────────┴─────┴─────┘r2 from prediction: 0.15937814745521017

迅速なイテレーションと迅速なプロトタイピングを優先するデータ サイエンティスト向けに、Rust のコンパイル時間は頭痛の種になる可能性があります。 Rust の強力な静的型システムは、型の安全性を確保し、実行時エラーを削減するのに適していますが、コーディング プロセスの複雑さをさらに高めます。

以上がRust の Linfa ライブラリと Polars ライブラリを使用した機械学習: 線形回帰の詳細内容です。詳細については、PHP 中国語 Web サイトの他の関連記事を参照してください。

声明:
この記事は51cto.comで複製されています。侵害がある場合は、admin@php.cn までご連絡ください。