Skip to content

Commit

Permalink
Track rewrite stats in report (egg only)
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusrossel committed Nov 12, 2024
1 parent ee40d5d commit c363a09
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 35 deletions.
7 changes: 5 additions & 2 deletions C/ffi.c
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ structure Report where
nodeCount: Nat
classCount: Nat
time: Float
rwStats: String
*/

typedef struct report {
Expand All @@ -215,15 +216,17 @@ typedef struct report {
size_t node_count;
size_t class_count;
double time;
char* rw_stats;
} report;

lean_obj_res report_to_lean(report rep) {
lean_object* r = lean_alloc_ctor(0, 3, sizeof(double) + sizeof(uint8_t));
size_t obj_offset = 3 * sizeof(void*);
lean_object* r = lean_alloc_ctor(0, 4, sizeof(double) + sizeof(uint8_t));
size_t obj_offset = 4 * sizeof(void*);

lean_ctor_set(r, 0, lean_box(rep.iterations));
lean_ctor_set(r, 1, lean_box(rep.node_count));
lean_ctor_set(r, 2, lean_box(rep.class_count));
lean_ctor_set(r, 3, lean_mk_string(rep.rw_stats));
lean_ctor_set_float(r, obj_offset, rep.time);
lean_ctor_set_uint8(r, obj_offset + sizeof(double), stop_reason_to_lean(rep.reason));

Expand Down
1 change: 1 addition & 0 deletions Lean/Egg/Core/Request/Basic.lean
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ structure Result.Report where
nodeCount: Nat
classCount: Nat
time: Float
rwStats: String

-- IMPORTANT: The C interface to egg depends on the order of these fields.
private structure Result.Raw where
Expand Down
3 changes: 2 additions & 1 deletion Lean/Egg/Tactic/Trace.lean
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ nonrec def formatReport
"nodes: " ++ (format rep.nodeCount) ++ "\n" ++
"classes: " ++ (format rep.classCount) ++ "\n" ++
(if let some e := expl? then "expl steps: " ++ format e.steps.size ++ s!"\nbinder rws: {e.involvesBinderRewrites}\n" else "") ++
s!"⊢ binders: {goalContainsBinder}"
s!"⊢ binders: {goalContainsBinder}" ++
(if rep.rwStats.isEmpty then "" else s!"\nrw stats:\n{rep.rwStats}")

nonrec def MVars.toMessageData (mvars : MVars) : MetaM MessageData := do
let expr := format <| ← mvars.expr.toList.mapM (ppExpr <| Expr.mvar ·)
Expand Down
44 changes: 40 additions & 4 deletions Rust/Egg/src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::result::*;
use crate::analysis::*;
use crate::beta::*;
use crate::eta::*;
use crate::lean_expr::*;
use crate::levels::*;
use crate::nat_lit::*;
use crate::rewrite::*;
Expand All @@ -29,7 +30,14 @@ pub struct Config {
allow_unsat_conditions: bool
}

pub fn explain_congr(init: String, goal: String, rw_templates: Vec<RewriteTemplate>, facts: Vec<(String, String)>, guides: Vec<String>, cfg: Config, viz_path: Option<String>) -> Result<(String, LeanEGraph, Report), Error> {
pub struct ExplainedCongr {
pub expl: String,
pub egraph: LeanEGraph,
pub report: Report,
pub rw_stats: String
}

pub fn explain_congr(init: String, goal: String, rw_templates: Vec<RewriteTemplate>, facts: Vec<(String, String)>, guides: Vec<String>, cfg: Config, viz_path: Option<String>) -> Result<ExplainedCongr, Error> {
let analysis = LeanAnalysis { union_semantics: cfg.union_semantics };
let mut egraph: LeanEGraph = EGraph::new(analysis);
egraph = egraph.with_explanations_enabled();
Expand Down Expand Up @@ -62,7 +70,8 @@ pub fn explain_congr(init: String, goal: String, rw_templates: Vec<RewriteTempla
if cfg.eta_expand { rws.push(eta_expansion_rw()) }
if cfg.beta { rws.push(beta_reduction_rw()) }
if cfg.levels { rws.append(&mut level_rws()) }
// TODO: Only add these rws if on of the following is active: beta, eta, bvar index correction. Anything else?
// TODO: Only add these rws if one of the following is active: beta, eta, eta-expansion,
// bvar index correction. Anything else?
rws.append(&mut subst_rws());
rws.append(&mut shift_rws());

Expand All @@ -84,12 +93,39 @@ pub fn explain_congr(init: String, goal: String, rw_templates: Vec<RewriteTempla
.run(&rws);

let report = runner.report();
let rw_stats = collect_rw_stats(&runner);

if runner.egraph.find(init_id) == runner.egraph.find(goal_id) {
let mut expl = runner.explain_equivalence(&init_expr, &goal_expr);
let expl_str = expl.get_flat_string();
Ok((expl_str, runner.egraph, report))
Ok(ExplainedCongr { expl: expl_str, egraph: runner.egraph, report, rw_stats })
} else {
Ok(("".to_string(), runner.egraph, report))
Ok(ExplainedCongr { expl: "".to_string(), egraph: runner.egraph, report, rw_stats })
}
}

fn collect_rw_stats(runner: &Runner<LeanExpr, LeanAnalysis>) -> String {
let mut stats: HashMap<String, usize> = Default::default();
let mut longest_rw: usize = 0;

for iter in &runner.iterations {
for (rw, count) in &iter.applied {
let rw_str = rw.to_string();
let normal_rw = rw_str.strip_suffix("-rev").unwrap_or(&rw_str);
longest_rw = longest_rw.max(normal_rw.chars().count());

let current = stats.get(normal_rw).unwrap_or(&0);
stats.insert(normal_rw.to_string(), current + count);
}
}

let mut entries: Vec<_> = stats.iter().collect();
entries.sort_by(|l, r| l.0.cmp(r.0));

entries.iter().map(|e| {
let padding = 1 + longest_rw - e.0.chars().count();
format!("{}:{}{}", e.0, " ".repeat(padding), e.1)
})
.collect::<Vec<_>>()
.join("\n")
}
37 changes: 19 additions & 18 deletions Rust/Egg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ use egg::*;
use core::ffi::c_char;
use core::ffi::CStr;
use std::ffi::CString;
use std::ptr::null;
use libc::c_double;
use std::str::FromStr;
use basic::*;
Expand Down Expand Up @@ -30,6 +31,12 @@ fn c_str_to_string(c_str: *const c_char) -> String {
String::from_utf8_lossy(str.to_bytes()).to_string()
}

// TODO: I think this is a memory leak right now.
fn string_to_c_str(str: String) -> *const c_char {
let expl_c_str = CString::new(str).expect("conversion of Rust-string to C-string failed");
expl_c_str.into_raw()
}

#[repr(C)]
pub struct CStringArray {
ptr: *const *const c_char,
Expand Down Expand Up @@ -155,17 +162,19 @@ pub struct CReport {
egraph_nodes: usize,
egraph_classes: usize,
total_time: c_double,
rw_stats: *const c_char,
}

impl CReport {

fn from_report(r: Report) -> CReport {
fn from_report(r: Report, rw_stats: String) -> CReport {
CReport {
iterations: r.iterations,
stop_reason: CStopReason::from_stop_reason(r.stop_reason),
egraph_nodes: r.egraph_nodes,
egraph_classes: r.egraph_classes,
total_time: r.total_time,
rw_stats: string_to_c_str(rw_stats),
}
}

Expand All @@ -176,6 +185,7 @@ impl CReport {
egraph_nodes: 0,
egraph_classes: 0,
total_time: 0.0,
rw_stats: null(),
}
}
}
Expand All @@ -202,33 +212,25 @@ pub extern "C" fn egg_explain_congr(
let guides = guides.to_vec();
let facts = facts.to_vec();

// Note: The `into_raw`s below are important, as otherwise Rust deallocates the string.
// TODO: I think this is a memory leak right now.

let rw_templates = rws.to_templates();
if let Err(rws_err) = rw_templates {
let rws_err_c_str = CString::new(rws_err.to_string()).expect("conversion of error message to C-string failed");
return EqsatResult { expl: rws_err_c_str.into_raw(), graph: None, report: CReport::none() }
return EqsatResult { expl: string_to_c_str(rws_err.to_string()), graph: None, report: CReport::none() }
}
let rw_templates = rw_templates.unwrap();

let viz_path_c_str = unsafe { CStr::from_ptr(viz_path_ptr) };
let raw_viz_path = String::from_utf8_lossy(viz_path_c_str.to_bytes()).to_string();
let raw_viz_path = c_str_to_string(viz_path_ptr);
let viz_path = if raw_viz_path.is_empty() { None } else { Some(raw_viz_path) };

let res = explain_congr(init, goal, rw_templates, facts, guides, cfg, viz_path);
if let Err(res_err) = res {
let res_err_c_str = CString::new(res_err.to_string()).expect("conversion of error message to C-string failed");
return EqsatResult { expl: res_err_c_str.into_raw(), graph: None, report: CReport::none() }
return EqsatResult { expl: string_to_c_str(res_err.to_string()), graph: None, report: CReport::none() }
}
let (expl, egraph, report) = res.unwrap();

let expl_c_str = CString::new(expl).expect("conversion of explanation to C-string failed");
let ExplainedCongr { expl, egraph, report, rw_stats } = res.unwrap();

return EqsatResult {
expl: expl_c_str.into_raw(),
expl: string_to_c_str(expl),
graph: Some(Box::new(egraph)),
report: CReport::from_report(report)
report: CReport::from_report(report, rw_stats)
}
}

Expand All @@ -247,10 +249,9 @@ pub unsafe extern "C" fn egg_query_equiv(
if egraph.find(init_id) == egraph.find(goal_id) {
let mut expl = egraph.explain_equivalence(&init, &goal);
let expl_str = expl.get_flat_string();
let expl_c_str = CString::new(expl_str.to_string()).expect("conversion of explanation to C-string failed");
expl_c_str.into_raw()
string_to_c_str(expl_str)
} else {
CString::new("").unwrap().into_raw()
string_to_c_str("".to_string())
}
}

Expand Down
25 changes: 15 additions & 10 deletions Rust/Slotted/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use slotted_egraphs::*;
use core::ffi::c_char;
use core::ffi::CStr;
use std::ffi::CString;
use std::ptr::null;
use libc::c_double;
use basic::*;
use analysis::*;
Expand All @@ -24,6 +25,12 @@ fn c_str_to_string(c_str: *const c_char) -> String {
String::from_utf8_lossy(str.to_bytes()).to_string()
}

// TODO: I think this is a memory leak right now.
fn string_to_c_str(str: String) -> *const c_char {
let expl_c_str = CString::new(str).expect("conversion of Rust-string to C-string failed");
expl_c_str.into_raw()
}

#[repr(C)]
pub struct CStringArray {
ptr: *const *const c_char,
Expand Down Expand Up @@ -148,6 +155,7 @@ pub struct CReport {
egraph_nodes: usize,
egraph_classes: usize,
total_time: c_double,
rw_stats: *const c_char,
}

impl CReport {
Expand All @@ -159,6 +167,7 @@ impl CReport {
egraph_nodes: r.egraph_nodes,
egraph_classes: r.egraph_classes,
total_time: r.total_time,
rw_stats: string_to_c_str("".to_string()),
}
}

Expand All @@ -169,6 +178,7 @@ impl CReport {
egraph_nodes: 0,
egraph_classes: 0,
total_time: 0.0,
rw_stats: null(),
}
}
}
Expand Down Expand Up @@ -200,8 +210,7 @@ pub extern "C" fn slotted_explain_congr(

let rw_templates = rws.to_templates();
if let Err(rws_err) = rw_templates {
let rws_err_c_str = CString::new(rws_err.to_string()).expect("conversion of error message to C-string failed");
return EqsatResult { expl: rws_err_c_str.into_raw(), graph: None, report: CReport::none() }
return EqsatResult { expl: string_to_c_str(rws_err.to_string()), graph: None, report: CReport::none() }
}
let rw_templates = rw_templates.unwrap();

Expand All @@ -211,15 +220,12 @@ pub extern "C" fn slotted_explain_congr(

let res = explain_congr(init, goal, rw_templates, facts, guides, cfg, viz_path);
if let Err(res_err) = res {
let res_err_c_str = CString::new(res_err.to_string()).expect("conversion of error message to C-string failed");
return EqsatResult { expl: res_err_c_str.into_raw(), graph: None, report: CReport::none() }
return EqsatResult { expl: string_to_c_str(res_err.to_string()), graph: None, report: CReport::none() }
}
let (expl, egraph, report) = res.unwrap();

let expl_c_str = CString::new(expl).expect("conversion of explanation to C-string failed");

return EqsatResult {
expl: expl_c_str.into_raw(),
expl: string_to_c_str(expl),
graph: Some(Box::new(egraph)),
report: CReport::from_report(report)
}
Expand All @@ -239,10 +245,9 @@ pub unsafe extern "C" fn slotted_query_equiv(

if egraph.eq(&init_id, &goal_id) {
let expl = egraph.explain_equivalence(init, goal).to_flat_string(&egraph);
let expl_c_str = CString::new(expl).expect("conversion of explanation to C-string failed");
expl_c_str.into_raw()
string_to_c_str(expl)
} else {
CString::new("").unwrap().into_raw()
string_to_c_str("".to_string())
}
}

Expand Down

0 comments on commit c363a09

Please sign in to comment.