From bcbd23439b925df3a91bec10ba1b425dd72bd0f9 Mon Sep 17 00:00:00 2001
From: kit
Date: Sat, 7 Sep 2024 23:49:21 -0400
Subject: [PATCH] added a plot to visualize model's learning curve

---
Review notes (this section is ignored by `git am`):
- `learning_curve` is now imported from `sklearn.model_selection`; it is not
  part of `sklearn.metrics`, so the previous import line raised ImportError
  as soon as visualization.py was imported.
- The new visualization.py content now ends with a trailing newline.
- NOTE(review): `predict_new_data` and `feature_names` are used in
  `generate_model` but are not imported/defined in the visible hunk context;
  confirm they are in scope in main.py.
- NOTE(review): whitespace-only context lines were reconstructed from the
  hunk line counts; verify the patch applies cleanly against the tree.

 main.py          |  7 ++++++-
 visualization.py | 40 ++++++++++++++++++++++++++++++++++++++--
 2 files changed, 44 insertions(+), 3 deletions(-)

diff --git a/main.py b/main.py
index 42c6c32..c4bd8a0 100755
--- a/main.py
+++ b/main.py
@@ -3,7 +3,7 @@
 from model import train_model_cv, evaluate_model, save_model, compare_models, optimize_hyperparameters
 from interpretation import explain_model
 from ensemble import create_ensemble
-from visualization import plot_feature_importance, plot_confusion_matrix, plot_correlation_matrix
+from visualization import plot_feature_importance, plot_confusion_matrix, plot_correlation_matrix, visualize_results
 
 #!/usr/bin/env python3
 import argparse
@@ -26,10 +26,15 @@ def generate_model(positive_file: str, negative_file: str):
     print("Training model...")
     model, scaler = train_model_cv(X_augmented, y_augmented)
 
+    print("Generating visualization plots...")
+    y_pred = predict_new_data(model, scaler, X_augmented)
+    visualize_results(model, X_augmented, y_augmented, y_pred, feature_names)
+
     print("Saving model...")
     save_model(model, scaler, "protein_interaction_model.joblib")
 
     print("Model generation complete. Saved as 'protein_interaction_model.joblib'")
+    print("Visualization plots have been generated and saved.")
 
 def main():
     parser = argparse.ArgumentParser(description="Predict protein interactions or generate model.")
diff --git a/visualization.py b/visualization.py
index 6331c83..416ef0e 100644
--- a/visualization.py
+++ b/visualization.py
@@ -4,7 +4,8 @@
 import numpy as np
 from xgboost import XGBClassifier
 from sklearn.metrics import confusion_matrix
-from typing import List
+from sklearn.model_selection import learning_curve
+from typing import List, Tuple
 
 def plot_feature_importance(model: XGBClassifier, feature_names: List[str]) -> None:
     """Plot feature importance using seaborn."""
@@ -39,4 +40,39 @@ def plot_correlation_matrix(X: np.ndarray, feature_names: List[str]) -> None:
     plt.title('Feature Correlation Matrix')
     plt.tight_layout()
     plt.savefig("correlation_matrix.png")
-    plt.close()
\ No newline at end of file
+    plt.close()
+
+def plot_learning_curve(estimator: XGBClassifier, X: np.ndarray, y: np.ndarray) -> None:
+    """Plot the 5-fold cross-validated ROC AUC learning curve and save it to learning_curve.png."""
+    train_sizes, train_scores, test_scores = learning_curve(
+        estimator, X, y, cv=5, n_jobs=-1,
+        train_sizes=np.linspace(0.1, 1.0, 10), scoring="roc_auc"
+    )
+
+    train_mean = np.mean(train_scores, axis=1)
+    train_std = np.std(train_scores, axis=1)
+    test_mean = np.mean(test_scores, axis=1)
+    test_std = np.std(test_scores, axis=1)
+
+    plt.figure(figsize=(10, 6))
+    plt.plot(train_sizes, train_mean, label="Training score", color="blue")
+    plt.fill_between(train_sizes, train_mean - train_std, train_mean + train_std, alpha=0.1, color="blue")
+    plt.plot(train_sizes, test_mean, label="Cross-validation score", color="red")
+    plt.fill_between(train_sizes, test_mean - test_std, test_mean + test_std, alpha=0.1, color="red")
+
+    plt.title("Learning Curve")
+    plt.xlabel("Training Examples")
+    plt.ylabel("ROC AUC Score")
+    plt.legend(loc="best")
+    plt.grid(True)
+    plt.tight_layout()
+    plt.savefig("learning_curve.png")
+    plt.close()
+
+def visualize_results(model: XGBClassifier, X: np.ndarray, y: np.ndarray, y_pred: np.ndarray, feature_names: List[str]) -> None:
+    """Generate and save all visualization plots."""
+    plot_feature_importance(model, feature_names)
+    plot_confusion_matrix(y, y_pred)
+    plot_correlation_matrix(X, feature_names)
+    plot_learning_curve(model, X, y)
+    print("All visualization plots have been generated and saved.")