Skip to content

Commit

Permalink
gh error, now really: added saving of df with full results, needed fo…
Browse files Browse the repository at this point in the history
…r bookkeeping but also for further investigations into characteristics of diseases easily found (or not) by LLM (#39)

It's really a small difference which should be in here before any further runs are carried out, so I am merging this now.
  • Loading branch information
leokim-l authored Jul 24, 2024
1 parent 0cd2c60 commit b23dd69
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
16 changes: 9 additions & 7 deletions src/malco/post_process/compute_mrr.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ def mondo_adapter() -> OboGraphInterface:
"""
return get_adapter("sqlite:obo:mondo")

def compute_mrr(output_dir, prompt_dir, correct_answer_file) -> Path:
def compute_mrr(output_dir, prompt_dir, correct_answer_file,
raw_results_dir) -> Path:
# Read in results TSVs from self.output_dir that match glob results*tsv
results_data = []
results_files = []
Expand Down Expand Up @@ -83,6 +84,10 @@ def compute_mrr(output_dir, prompt_dir, correct_answer_file) -> Path:
lambda row: 1 / row["rank"] if row["is_correct"] else 0, axis=1
)

# Save full data frame
full_df_file = raw_results_dir / results_files[i][0:2] / "full_df_results.tsv"
df.to_csv(full_df_file, sep='\t', index=False)

# Calculate MRR for this file
mrr = df.groupby("label")["reciprocal_rank"].max().mean()
mrr_scores.append(mrr)
Expand All @@ -109,7 +114,6 @@ def compute_mrr(output_dir, prompt_dir, correct_answer_file) -> Path:
else:
# increase n10p
rank_df.loc[i,"n10p"] += 1



# Write cache characteristics to file
Expand All @@ -122,15 +126,13 @@ def compute_mrr(output_dir, prompt_dir, correct_answer_file) -> Path:
i = i + 1


topn_file = output_dir / "plots/topn_result.tsv"
# use rank_df.to_csv() or something similar
plot_dir = output_dir / "plots"
plot_dir.mkdir(exist_ok=True)
topn_file = plot_dir / "topn_result.tsv"
rank_df.to_csv(topn_file, sep='\t', index=False)


print("MRR scores are:\n")
print(mrr_scores)
plot_dir = output_dir / "plots"
plot_dir.mkdir(exist_ok=True)
plot_data_file = plot_dir / "plotting_data.tsv"

# write out results for plotting
Expand Down
3 changes: 2 additions & 1 deletion src/malco/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ def post_process(self,
plot_data_file, plot_dir, num_ppkt, topn_file = compute_mrr(
output_dir=self.output_dir,
prompt_dir=os.path.join(self.input_dir, prompts_subdir_name),
correct_answer_file=correct_answer_file)
correct_answer_file=correct_answer_file,
raw_results_dir=self.raw_results_dir)

if print_plot:
make_plots(plot_data_file, plot_dir, self.languages, num_ppkt, topn_file)

0 comments on commit b23dd69

Please sign in to comment.