From d0cc1b1c11aea358f18af2ec32881979127fd470 Mon Sep 17 00:00:00 2001 From: kit Date: Fri, 6 Sep 2024 23:53:31 -0400 Subject: [PATCH] added fixes for model generation --- main.py | 85 +++++++++++++++++++++++++++++++++++++------------------- model.py | 4 +-- setup.py | 5 ++++ 3 files changed, 64 insertions(+), 30 deletions(-) diff --git a/main.py b/main.py index f0ed788..42c6c32 100755 --- a/main.py +++ b/main.py @@ -11,43 +11,72 @@ from Bio import SeqIO from Bio.Seq import Seq from features import extract_features -from model import load_model, predict_new_data +from model import load_model, predict_new_data, train_model_cv, save_model +from data_pipeline import load_and_preprocess_data, augment_data import numpy as np +def generate_model(positive_file: str, negative_file: str): + """Generate and save the protein interaction model.""" + print("Loading and preprocessing data...") + X, y, feature_names = load_and_preprocess_data(positive_file, negative_file) + + print("Augmenting data...") + X_augmented, y_augmented = augment_data(X, y) + + print("Training model...") + model, scaler = train_model_cv(X_augmented, y_augmented) + + print("Saving model...") + save_model(model, scaler, "protein_interaction_model.joblib") + + print("Model generation complete. Saved as 'protein_interaction_model.joblib'") + def main(): - parser = argparse.ArgumentParser(description="Predict protein interactions.") - parser.add_argument("sequence_file", help="Path to the protein sequence file (FASTA format)") + parser = argparse.ArgumentParser(description="Predict protein interactions or generate model.") + parser.add_argument("--generate", action="store_true", help="Generate new model") + parser.add_argument("--positive", help="Path to positive interaction FASTA file (for model generation)") + parser.add_argument("--negative", help="Path to negative interaction FASTA file (for model generation)") + parser.add_argument("sequence_file", nargs="?", help="Path to the protein sequence file (FASTA format) for prediction") args = parser.parse_args() - try: - # Load the sequence - with open(args.sequence_file, "r") as handle: - record = next(SeqIO.parse(handle, "fasta")) - sequence = record.seq - except FileNotFoundError: - print(f"Error: File '{args.sequence_file}' not found.") - sys.exit(1) - except StopIteration: - print(f"Error: No sequences found in '{args.sequence_file}'.") - sys.exit(1) + if args.generate: + if not args.positive or not args.negative: + print("Error: Both --positive and --negative files are required for model generation.") + sys.exit(1) + generate_model(args.positive, args.negative) + elif args.sequence_file: + try: + # Load the sequence + with open(args.sequence_file, "r") as handle: + record = next(SeqIO.parse(handle, "fasta")) + sequence = record.seq + except FileNotFoundError: + print(f"Error: File '{args.sequence_file}' not found.") + sys.exit(1) + except StopIteration: + print(f"Error: No sequences found in '{args.sequence_file}'.") + sys.exit(1) - # Extract features - features = extract_features(sequence) - X_new = np.array([list(features.values())]) + # Extract features + features = extract_features(sequence) + X_new = np.array([list(features.values())]) - # Load the pre-trained model - try: - model, scaler = load_model("protein_interaction_model.joblib") - except FileNotFoundError: - print("Error: Pre-trained model not found. Please ensure 'protein_interaction_model.joblib' is in the current directory.") - sys.exit(1) + # Load the pre-trained model + try: + model, scaler = load_model("protein_interaction_model.joblib") + except FileNotFoundError: + print("Error: Pre-trained model not found. Please ensure 'protein_interaction_model.joblib' is in the current directory.") + sys.exit(1) - # Make prediction - prediction = predict_new_data(model, scaler, X_new) + # Make prediction + prediction = predict_new_data(model, scaler, X_new) - # Print result - result = "likely to interact" if prediction[0] == 1 else "unlikely to interact" - print(f"The protein sequence in '{args.sequence_file}' is {result}.") + # Print result + result = "likely to interact" if prediction[0] == 1 else "unlikely to interact" + print(f"The protein sequence in '{args.sequence_file}' is {result}.") + else: + print("Error: Please provide either --generate with --positive and --negative files, or a sequence file for prediction.") + sys.exit(1) if __name__ == "__main__": main() \ No newline at end of file diff --git a/model.py b/model.py index 1956ae8..1308c71 100644 --- a/model.py +++ b/model.py @@ -1,6 +1,6 @@ import numpy as np -from typing import Tuple, Dict -from sklearn.model_selection import StratifiedKFold, GridSearchCV +from typing import Tuple, Dict, Any +from sklearn.model_selection import StratifiedKFold, cross_val_score from sklearn.preprocessing import StandardScaler from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score from xgboost import XGBClassifier diff --git a/setup.py b/setup.py index 21cbe77..8a7cf67 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,11 @@ "scikit-learn", "xgboost", "joblib", + "pandas", + "matplotlib", + "seaborn", + "imbalanced-learn", + "optuna", ], entry_points={ "console_scripts": [