Commit

wasm binaries

vkomenda committed Aug 11, 2024
1 parent e575e75 commit e6aed0a
Showing 7 changed files with 55 additions and 49 deletions.
32 changes: 16 additions & 16 deletions Cargo.lock

Some generated files are not rendered by default.

10 changes: 10 additions & 0 deletions README.md
@@ -1,5 +1,15 @@
# Privacy-Preserving Machine Learning with Zero-Knowledge Proofs

### Build and run

```
cargo build --release --target wasm32-wasip1
RUST_LOG=debug cargo run --release
```

The path to the WASM binary may need to be corrected in `zk/src/main.rs`.


### Overview

Welcome to the zKML project, an initiative designed to revolutionize the way data privacy is handled in collaborative machine learning efforts. This project aims to address one of the most pressing challenges in the digital age: enabling data sharing for health or financial predictions while preserving the privacy and security of sensitive information.
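
A note on the README instruction above about correcting the WASM path: the path is set in `zk/src/main.rs` via the zk-engine `WASMArgsBuilder` call shown later in this commit. The fragment below is adapted from that file as committed (one builder option elided); point `file_path` at wherever the `wasm32-wasip1` release build lands, for example `target/wasm32-wasip1/release/ml.wasm`.

```
let args = WASMArgsBuilder::default()
    // Adjust this path to the WASM module you want to prove.
    .file_path(PathBuf::from("../wasm/gradient_boosting.wasm"))
    .invoke(Some(String::from("_start")))
    .build();
```
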
47 changes: 20 additions & 27 deletions ml/src/lib.rs
@@ -1,5 +1,5 @@
-use candle_core::{Device, Result, Tensor};
-use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW};
+use candle_core::{DType, Device, Tensor};
+use candle_nn::{linear, loss::mse, AdamW, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};
use std::error::Error;

#[no_mangle]
@@ -11,8 +11,8 @@ pub extern "C" fn _start() {
pub extern "C" fn run() {
match greet_internal() {
Ok(predictions) => {
println!("Model predictions (first 10):");
for (i, pred) in predictions.iter().take(10).enumerate() {
println!("Model predictions:");
for (i, pred) in predictions.iter().enumerate() {
println!("Prediction {}: {:.4}", i + 1, pred);
}
}
@@ -21,13 +21,6 @@ pub extern "C" fn run() {
}

fn greet_internal() -> Result<Vec<f32>, Box<dyn Error>> {
-// let csv_data = r#"survived,pclass,sex,age,sibsp,parch,fare,embarked,class,who,adult_male,deck,embark_town,alive,alone
-// 0,3,male,22.0,1,0,7.25,S,Third,man,True,,Southampton,no,False
-// 1,1,female,38.0,1,0,71.2833,C,First,woman,False,C,Cherbourg,yes,False
-// 1,3,female,26.0,0,0,7.925,S,Third,woman,False,,Southampton,yes,True
-// 1,1,female,35.0,1,0,53.1,S,First,woman,False,C,Southampton,yes,False
-// 0,3,male,35.0,0,0,8.05,S,Third,man,True,,Southampton,no,True"#;
-
// longitude,latitude,housing_median_age,total_rooms,total_bedrooms,population,households,median_income,median_house_value,ocean_proximity
let csv_data = r#"-122.27,37.85,40.0,751.0,184.0,409.0,166.0,1.3578,147500.0,NEAR BAY
-118.3,33.95,50.0,1843.0,326.0,892.0,314.0,3.1346,120000.0,<1H OCEAN
@@ -54,41 +47,41 @@ fn greet_internal() -> Result<Vec<f32>, Box<dyn Error>> {
println!("{features:?}");
println!("{target:?}");

-let num_samples = features.len();
let num_features = 8;
+let num_samples = features.len() / num_features;

// Convert data to tensors
let features_tensor =
Tensor::from_slice(&features, &[num_samples, num_features], &Device::Cpu)?;
let target_tensor = Tensor::from_slice(&target, &[num_samples, 1], &Device::Cpu)?;

-// Step 6: Define the linear regression model
-let mut model = Linear::new(num_features, 1);
+// Define the linear regression model
+let varmap = VarMap::new();
+let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
+let model = linear(num_features, 1, vb.pp("linear"))?;

-// Step 7: Set up the optimizer
-let mut optimizer = AdamW::new(
-vec![],
-ParamsAdamW::default(), // OptimizerConfig::adam(0.01).build(model.parameters()),
-);
+// Set up the optimizer
+let params = ParamsAdamW {
+lr: 0.1,
+..Default::default()
+};
+let mut optimizer = AdamW::new(vec![], params)?;

-// Step 8: Training Loop
+// Training Loop
let num_epochs = 100;
for epoch in 0..num_epochs {
let predictions = model.forward(&features_tensor)?;
-let loss = mse_loss(&predictions, &target_tensor)?;
-optimizer.step(&loss)?;
+let loss = mse(&predictions, &target_tensor)?;
+optimizer.backward_step(&loss)?;

if epoch % 10 == 0 {
println!("Epoch {}: Loss = {:?}", epoch, loss);
}
}

-// Step 9: Make and print predictions
-let predictions = model.forward(&features_tensor)?;
-println!(
-"Model predictions (first 10): {:?}",
-predictions.slice([..10])
-);
+let predictions = model.forward(&features_tensor)?.squeeze(1)?.to_vec1()?;
+println!("Model predictions: {predictions:?}");

Ok(predictions)
}
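
For orientation, here is a minimal self-contained sketch of the candle_nn pattern this hunk adopts (VarMap + VarBuilder, `linear`, `loss::mse`, and `AdamW::backward_step`). It is illustrative rather than part of the commit: `train_linear` is a made-up name, and unlike the committed `AdamW::new(vec![], params)?` it passes `varmap.all_vars()` so the optimizer actually has parameters to update.

```
use candle_core::{DType, Device, Result, Tensor};
use candle_nn::{linear, loss::mse, AdamW, Module, Optimizer, ParamsAdamW, VarBuilder, VarMap};

fn train_linear(xs: &[f32], ys: &[f32], num_features: usize) -> Result<Vec<f32>> {
    let dev = Device::Cpu;
    let n = xs.len() / num_features;
    let x = Tensor::from_slice(xs, (n, num_features), &dev)?;
    let y = Tensor::from_slice(ys, (n, 1), &dev)?;

    // Trainable parameters live in the VarMap; the VarBuilder wires them into `linear`.
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &dev);
    let model = linear(num_features, 1, vb.pp("linear"))?;

    let params = ParamsAdamW { lr: 0.1, ..Default::default() };
    let mut opt = AdamW::new(varmap.all_vars(), params)?;

    for epoch in 0..100 {
        let predictions = model.forward(&x)?;
        let loss = mse(&predictions, &y)?;
        opt.backward_step(&loss)?;
        if epoch % 10 == 0 {
            println!("Epoch {epoch}: Loss = {loss:?}");
        }
    }

    // Collapse the (n, 1) prediction tensor into a plain Vec<f32>.
    model.forward(&x)?.squeeze(1)?.to_vec1::<f32>()
}
```
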
Binary file added wasm/gradient_boosting.wasm
Binary file not shown.
Binary file added wasm/ml.wasm
Binary file not shown.
2 changes: 1 addition & 1 deletion zk/Cargo.toml
@@ -1,5 +1,5 @@
[package]
name = "gradient-boosting-prover"
name = "ml-prover"
version = "0.1.0"
edition = "2021"

13 changes: 8 additions & 5 deletions zk/src/main.rs
@@ -1,7 +1,7 @@
-use ff::PrimeField;
+// use ff::PrimeField;
use rayon::iter::{IntoParallelIterator, ParallelIterator};
use std::{path::PathBuf, time::Instant};
-use tracing::{debug, info};
+use tracing::info;
use zk_engine::{
args::{WASMArgsBuilder, WASMCtx},
nova::{
@@ -51,7 +51,8 @@ fn main() -> anyhow::Result<()> {
//
// Here we are configuring the path to the WASM file
let args = WASMArgsBuilder::default()
.file_path(PathBuf::from("wasm/gradient_boosting.wasm"))
// .file_path(PathBuf::from("/home/vk/src/zkhack/montreal/zkml-montreal/target/wasm32-wasip1/release/ml.wasm"))
.file_path(PathBuf::from("../wasm/gradient_boosting.wasm"))
.invoke(Some(String::from("_start")))
.trace_slice_values(TraceSliceValues::new(CHUNK_SIZE * i, CHUNK_SIZE * (i + 1)))
.build();
@@ -61,11 +62,13 @@ fn main() -> anyhow::Result<()> {
// Create a WASM execution context for proving.
let mut wasm_ctx = WASMCtx::new_from_file(args).map_err(|e| e.to_string())?;

-let (proof, public_values) =
+let (proof, public_values, wasm_func_res) =
BatchedZKEProof::<E1, BS1<E1>, S1<E1>, S2<E1>>::prove_wasm(&mut wasm_ctx)
.map_err(|e| e.to_string())?;

-let zi = public_values.execution().public_outputs();
+info!("wasm result {:?}", wasm_func_res);
+
+// let zi = public_values.execution().public_outputs();

let task_end_time = Instant::now();
let elapsed_time = (task_end_time - task_start_time).as_secs();

0 comments on commit e6aed0a
