Skip to content

Commit

Permalink
added a plot to visualize model's learning curve
Browse files Browse the repository at this point in the history
  • Loading branch information
yayekit committed Sep 8, 2024
1 parent d0cc1b1 commit bcbd234
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 4 deletions.
7 changes: 6 additions & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from model import train_model_cv, evaluate_model, save_model, compare_models, optimize_hyperparameters
from interpretation import explain_model
from ensemble import create_ensemble
from visualization import plot_feature_importance, plot_confusion_matrix, plot_correlation_matrix
from visualization import plot_feature_importance, plot_confusion_matrix, plot_correlation_matrix, visualize_results

#!/usr/bin/env python3
import argparse
Expand All @@ -26,10 +26,15 @@ def generate_model(positive_file: str, negative_file: str):
print("Training model...")
model, scaler = train_model_cv(X_augmented, y_augmented)

print("Generating visualization plots...")
y_pred = predict_new_data(model, scaler, X_augmented)
visualize_results(model, X_augmented, y_augmented, y_pred, feature_names)

print("Saving model...")
save_model(model, scaler, "protein_interaction_model.joblib")

print("Model generation complete. Saved as 'protein_interaction_model.joblib'")
print("Visualization plots have been generated and saved.")

def main():
parser = argparse.ArgumentParser(description="Predict protein interactions or generate model.")
Expand Down
41 changes: 38 additions & 3 deletions visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
import pandas as pd
import numpy as np
from xgboost import XGBClassifier
from sklearn.metrics import confusion_matrix
from typing import List
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import learning_curve
from typing import List, Tuple

def plot_feature_importance(model: XGBClassifier, feature_names: List[str]) -> None:
"""Plot feature importance using seaborn."""
Expand Down Expand Up @@ -39,4 +39,39 @@ def plot_correlation_matrix(X: np.ndarray, feature_names: List[str]) -> None:
plt.title('Feature Correlation Matrix')
plt.tight_layout()
plt.savefig("correlation_matrix.png")
plt.close()
plt.close()

def plot_learning_curve(estimator: XGBClassifier, X: np.ndarray, y: np.ndarray) -> None:
    """Plot the estimator's ROC AUC learning curve and save it as learning_curve.png.

    Scores are computed with scikit-learn's ``learning_curve`` using 5-fold
    cross-validation at ten training-set fractions evenly spaced from 10% to
    100%. The mean training and cross-validation scores are drawn as lines,
    each with a shaded band of +/- one standard deviation across folds.

    Args:
        estimator: Classifier handed to ``learning_curve``.
        X: Feature matrix.
        y: Target labels (scored with ROC AUC, so presumably binary --
            confirm against the caller).

    Returns:
        None. The figure is written to ``learning_curve.png`` in the current
        working directory and then closed.
    """
    train_sizes, train_scores, test_scores = learning_curve(
        estimator, X, y, cv=5, n_jobs=-1,
        train_sizes=np.linspace(0.1, 1.0, 10), scoring="roc_auc",
        # Training subsets are prefixes of each fold's training rows. Without
        # shuffling, class-ordered input (e.g. all positives followed by all
        # negatives) yields single-class subsets at small sizes, which breaks
        # ROC AUC scoring. A fixed seed keeps the plot reproducible.
        shuffle=True, random_state=42,
    )

    # Aggregate across CV folds (axis 1); axis 0 indexes the training sizes.
    train_mean = np.mean(train_scores, axis=1)
    train_std = np.std(train_scores, axis=1)
    test_mean = np.mean(test_scores, axis=1)
    test_std = np.std(test_scores, axis=1)

    plt.figure(figsize=(10, 6))
    try:
        plt.plot(train_sizes, train_mean, label="Training score", color="blue")
        plt.fill_between(train_sizes, train_mean - train_std, train_mean + train_std, alpha=0.1, color="blue")
        plt.plot(train_sizes, test_mean, label="Cross-validation score", color="red")
        plt.fill_between(train_sizes, test_mean - test_std, test_mean + test_std, alpha=0.1, color="red")

        plt.title("Learning Curve")
        plt.xlabel("Training Examples")
        plt.ylabel("ROC AUC Score")
        plt.legend(loc="best")
        plt.grid(True)
        plt.tight_layout()
        plt.savefig("learning_curve.png")
    finally:
        # Always release the figure, even if plotting or saving fails, so
        # repeated calls do not accumulate open figures.
        plt.close()

def visualize_results(model: XGBClassifier, X: np.ndarray, y: np.ndarray, y_pred: np.ndarray, feature_names: List[str]) -> None:
    """Run every plotting helper for a trained model, then report completion.

    The helpers are invoked in a fixed order: feature importance, confusion
    matrix, feature correlation matrix, and learning curve.

    Args:
        model: Trained classifier, passed to the feature-importance and
            learning-curve plots.
        X: Feature matrix, passed to the correlation and learning-curve plots.
        y: True labels, passed to the confusion-matrix and learning-curve plots.
        y_pred: Predicted labels, passed to the confusion-matrix plot.
        feature_names: Names passed to the feature-importance and
            correlation plots.
    """
    # Each entry pairs a plotting helper with the positional arguments it takes.
    plot_jobs = (
        (plot_feature_importance, (model, feature_names)),
        (plot_confusion_matrix, (y, y_pred)),
        (plot_correlation_matrix, (X, feature_names)),
        (plot_learning_curve, (model, X, y)),
    )
    for plot_fn, plot_args in plot_jobs:
        plot_fn(*plot_args)

    print("All visualization plots have been generated and saved.")

0 comments on commit bcbd234

Please sign in to comment.