Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add psi server logic #30

Merged
merged 14 commits into from
Dec 3, 2024
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ rabbitmq = ["lapin"]
[dev-dependencies]
approx = "0.5.1"
http-body-util = "0.1.2"
ndarray-rand = "0.14.0"
mockito = "1.4.0"
sqlx-cli = { version = "0.7.4", default-features = false, features = ["native-tls", "postgres"] }
tower = { version = "0.4.13", features = ["util"] }
Expand Down
145 changes: 143 additions & 2 deletions src/alerts/psi/drift.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl PsiDrifter {
))
}

async fn get_drift_map(
pub async fn get_drift_map(
&self,
limit_timestamp: &NaiveDateTime,
db_client: &PostgresClient,
Expand Down Expand Up @@ -132,7 +132,7 @@ impl PsiDrifter {
filtered_drift_map
}

async fn generate_alerts(
pub async fn generate_alerts(
&self,
drift_map: &HashMap<String, f64>,
) -> Result<Option<HashMap<String, f64>>> {
Expand Down Expand Up @@ -246,3 +246,144 @@ impl PsiDrifter {
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array;
use ndarray_rand::rand_distr::Uniform;
use ndarray_rand::RandomExt;
use scouter::core::drift::psi::types::{
Bin, PsiAlertConfig, PsiDriftConfig, PsiFeatureDriftProfile,
};

/// Builds the `PsiDrifter` shared by the unit tests in this module.
///
/// A drift profile is created from a random 1030 x 3 array (uniform over
/// `1.0..100.0`) whose columns are named `feature_1`, `feature_2` and
/// `feature_3`. The alert config lists only `feature_1` and `feature_3`
/// (presumably the features to monitor, as `test_get_monitored_profiles`
/// expects exactly those two).
///
/// Panics if the config or the profile cannot be built.
fn get_test_drifter() -> PsiDrifter {
    let feature_names: Vec<String> = ["feature_1", "feature_2", "feature_3"]
        .iter()
        .map(|name| name.to_string())
        .collect();

    // Random training data: one column per entry in `feature_names`.
    let training_data = Array::random((1030, 3), Uniform::new(1.0, 100.0));

    let alert_config = PsiAlertConfig::new(
        None,
        None,
        Some(vec!["feature_1".to_string(), "feature_3".to_string()]),
        None,
        None,
    );

    let drift_config = PsiDriftConfig::new(
        Some("name".to_string()),
        Some("repo".to_string()),
        None,
        None,
        None,
        Some(alert_config),
        None,
    )
    .unwrap();

    let psi_monitor = PsiMonitor::new();
    let drift_profile = psi_monitor
        .create_2d_drift_profile(&feature_names, &training_data.view(), &drift_config)
        .unwrap();

    PsiDrifter::new(drift_profile)
}

/// Verifies that `get_monitored_profiles` returns exactly the two features
/// listed in the alert config built by `get_test_drifter`.
#[test]
fn test_get_monitored_profiles() {
    let drifter = get_test_drifter();

    let profiles_to_monitor = drifter.get_monitored_profiles();

    assert_eq!(profiles_to_monitor.len(), 2);

    // Compare the sorted ids against the exact expected set. This is
    // independent of the order the profiles come back in, and (unlike two
    // separate "is feature_1 or feature_3" checks) it fails when the same
    // feature is returned twice.
    let mut monitored_ids: Vec<&str> = profiles_to_monitor
        .iter()
        .map(|profile| profile.id.as_str())
        .collect();
    monitored_ids.sort_unstable();

    assert_eq!(monitored_ids, ["feature_1", "feature_3"]);
}

/// Verifies that `get_feature_bin_proportion_pairs` pairs each training bin
/// proportion with the observed proportion recorded for the same bin id.
#[test]
fn test_get_feature_bin_proportion_pairs() {
    let training_feat1_decile1_prop = 0.2;
    let training_feat1_decile2_prop = 0.5;
    let training_feat1_decile3_prop = 0.3;

    let feature_drift_profile = PsiFeatureDriftProfile {
        id: "feature_1".to_string(),
        bins: vec![
            Bin {
                id: "decile_1".to_string(),
                lower_limit: Some(0.1),
                upper_limit: Some(0.2),
                proportion: training_feat1_decile1_prop,
            },
            Bin {
                id: "decile_2".to_string(),
                lower_limit: Some(0.2),
                upper_limit: Some(0.4),
                proportion: training_feat1_decile2_prop,
            },
            Bin {
                id: "decile_3".to_string(),
                lower_limit: Some(0.4),
                upper_limit: Some(0.8),
                proportion: training_feat1_decile3_prop,
            },
        ],
        timestamp: Default::default(),
    };

    let mut observed_bin_proportions = HashMap::new();
    let mut bin_map = HashMap::new();
    let observed_feat1_decile1_prop = 0.6;
    let observed_feat1_decile2_prop = 0.3;
    let observed_feat1_decile3_prop = 0.1;
    bin_map.insert("decile_1".to_string(), observed_feat1_decile1_prop);
    bin_map.insert("decile_2".to_string(), observed_feat1_decile2_prop);
    bin_map.insert("decile_3".to_string(), observed_feat1_decile3_prop);
    observed_bin_proportions.insert("feature_1".to_string(), bin_map);

    let drifter = get_test_drifter();

    let proportion_pairs = drifter
        .get_feature_bin_proportion_pairs(&feature_drift_profile, &observed_bin_proportions);

    // The per-pair checks below run zero times on an empty result, so guard
    // against a vacuous pass. NOTE(review): assumes one pair is produced per
    // bin of the profile -- confirm against the implementation.
    assert_eq!(proportion_pairs.len(), feature_drift_profile.bins.len());

    // The training proportions are all distinct, so each one identifies its
    // bin and therefore the observed proportion it must be paired with.
    for (training, observed) in proportion_pairs.iter() {
        if *training == training_feat1_decile1_prop {
            assert_eq!(*observed, observed_feat1_decile1_prop);
        } else if *training == training_feat1_decile2_prop {
            assert_eq!(*observed, observed_feat1_decile2_prop);
        } else if *training == training_feat1_decile3_prop {
            assert_eq!(*observed, observed_feat1_decile3_prop);
        } else {
            panic!("test failed: proportion mismatch!");
        }
    }
}

/// Verifies that `filter_drift_map` keeps only the entry whose drift value
/// exceeds the threshold in effect for the test drifter.
#[test]
fn test_filter_drift_map() {
    let drifter = get_test_drifter();

    let drifting_feature = "feature_4".to_string();

    let all_drift_values = HashMap::from([
        ("feature_1".to_string(), 0.07),
        ("feature_2".to_string(), 0.2),
        ("feature_3".to_string(), 0.23),
        (drifting_feature.clone(), 0.3),
        ("feature_5".to_string(), 0.12),
    ]);

    // No custom PSI threshold was configured, so the default of 0.25 applies;
    // only the 0.3 entry is above it.
    let remaining = drifter.filter_drift_map(&all_drift_values);
    assert_eq!(remaining.len(), 1);
    assert!(remaining.contains_key(&drifting_feature));
}
}
53 changes: 52 additions & 1 deletion tests/drift_integration.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use chrono::NaiveDateTime;
use scouter::core::dispatch::dispatcher::dispatcher_logic::{ConsoleAlertDispatcher, Dispatch};

use scouter::core::drift::psi::types::PsiDriftProfile;
use scouter::core::drift::spc::types::SpcDriftProfile;
use scouter_server::alerts::base::DriftExecutor;
use scouter_server::alerts::psi::drift::PsiDrifter;
use scouter_server::alerts::spc::drift::SpcDrifter;
use scouter_server::sql::postgres::PostgresClient;
use sqlx::{Postgres, Row};
Expand Down Expand Up @@ -66,6 +67,56 @@ async fn test_drift_executor_separate() {
test_utils::teardown().await.unwrap();
}

/// Integration test for the PSI drift path against the test database.
///
/// Steps:
/// 1. Set up the database and load `scripts/populate_psi.sql`.
/// 2. Fetch a drift profile task and deserialize it as a `PsiDriftProfile`.
/// 3. Compute the drift map since the task's previous run (three features
///    expected).
/// 4. Override the drift value of `feature_1` to `0.3` and check that
///    `generate_alerts` reports exactly that feature.
#[tokio::test]
async fn test_drift_executor_psi_tasks() {
    let pool = test_utils::setup_db(true).await.unwrap();
    let db_client = PostgresClient::new(pool.clone()).unwrap();

    let populate_script = include_str!("scripts/populate_psi.sql");
    sqlx::raw_sql(populate_script).execute(&pool).await.unwrap();

    let mut transaction: sqlx::Transaction<Postgres> = db_client.pool.begin().await.unwrap();
    let profile = PostgresClient::get_drift_profile_task(&mut transaction)
        .await
        .unwrap();

    assert!(profile.is_some());

    let profile = profile.unwrap();
    let drift_profile: PsiDriftProfile = serde_json::from_str(&profile.profile).unwrap();

    assert_eq!(drift_profile.config.name, "model");
    assert_eq!(drift_profile.config.repository, "scouter");

    let previous_run: NaiveDateTime = profile.previous_run;

    // `drift_profile` is not used again, so hand it over instead of cloning.
    let drifter = PsiDrifter::new(drift_profile);

    let mut drift_map = drifter
        .get_drift_map(&previous_run, &db_client)
        .await
        .unwrap()
        .unwrap();

    assert_eq!(drift_map.len(), 3);

    // Force a drift value for feature_1. Fail right here if the key is
    // missing rather than silently skipping the override and failing later
    // with a less obvious message.
    *drift_map
        .get_mut("feature_1")
        .expect("drift map should contain feature_1") = 0.3;

    let alerts = drifter.generate_alerts(&drift_map).await.unwrap().unwrap();

    assert_eq!(alerts.len(), 1);

    let (feature_name, drift_value) = alerts.iter().next().unwrap();

    assert_eq!(feature_name, "feature_1");
    assert!((drift_value - 0.3).abs() < 1e-9);

    transaction.commit().await.unwrap();
    test_utils::teardown().await.unwrap();
}

#[tokio::test]
async fn test_drift_executor() {
let pool = test_utils::setup_db(true).await.unwrap();
Expand Down
Loading
Loading