-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Rewrite constant_fold_pass using dataflow framework #1603
base: acl/const_fold2
Are you sure you want to change the base?
Conversation
// Note we COULD filter out (avoid breaking) wires from other nodes that we are keeping. | ||
// This would insert fewer constants, but potentially expose less parallelism. | ||
.filter_map(|(n, ip)| { | ||
let (src, outp) = hugr.single_linked_output(n, ip).unwrap(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let (src, outp) = hugr.single_linked_output(n, ip).unwrap(); | |
let (src, outp) = hugr.single_linked_output(n, ip)?; |
I had to make this change locally after merging main, to fix failing tests. I believe inputs are often disconnected when working with various subgraph views.
let sig = op.signature(); | ||
let known_ins = sig | ||
.input_types() | ||
.iter() | ||
.enumerate() | ||
.zip(ins.iter()) | ||
.filter_map(|((i, ty), pv)| { | ||
Some((IncomingPort::from(i), pv.clone().try_into_value(ty).ok()?)) | ||
}) | ||
.collect::<Vec<_>>(); | ||
for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { | ||
outs[p.index()] = | ||
partial_from_const(self, ConstLocation::Field(p.index(), &node.into()), &v); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
let sig = op.signature(); | |
let known_ins = sig | |
.input_types() | |
.iter() | |
.enumerate() | |
.zip(ins.iter()) | |
.filter_map(|((i, ty), pv)| { | |
Some((IncomingPort::from(i), pv.clone().try_into_value(ty).ok()?)) | |
}) | |
.collect::<Vec<_>>(); | |
for (p, v) in op.constant_fold(&known_ins).unwrap_or_default() { | |
outs[p.index()] = | |
partial_from_const(self, ConstLocation::Field(p.index(), &node.into()), &v); | |
} | |
} | |
// If any inputs are bottom this node is currently unreachable. | |
// We should not propagate anything to our outputs here. | |
// TODO perhaps this should be done by the caller and we should have a | |
// precondition here that no ins are Bottom; this if statement | |
// would become a debug_assert!. | |
if ins.iter().any(|v| *v == PartialValue::Bottom) { | |
return; | |
} | |
let sig = op.signature(); | |
let known_ins = sig | |
.input_types() | |
.iter() | |
.enumerate() | |
.zip(ins.iter()) | |
.filter_map(|((i, ty), pv)| { | |
Some((IncomingPort::from(i), pv.clone().try_into_value(ty).ok()?)) | |
}) | |
.collect::<Vec<_>>(); | |
// Collect the results of the fold. Any output wires that are not | |
// returned by `constant_fold` are set to Top. | |
let fold_outs: HashMap<OutgoingPort, Value> = op | |
.constant_fold(&known_ins) | |
.into_iter() | |
.flat_map(Vec::into_iter) | |
.collect(); | |
for (p, o) in outs.iter_mut().enumerate() { | |
o.join_mut( | |
fold_outs | |
.get(&OutgoingPort::from(p)) | |
.map_or(PartialValue::Top, |v| { | |
partial_from_const(self, node, &mut vec![p.index()], v) | |
}), | |
); | |
} |
Remnants of #1157 after the "framework" of #1476 is pulled out, plus transformation roughly paralleling the old constant-folding code + dead-code-elimination. (Eliminating cases from conditionals/tail-loops and transforming into DFGs, left for another PR.)
All the existing constant-folding tests pass! 😀 🚀, plus at least one that didn't just fail-to-optimize, but panicked - or, in release mode, potentially produced an invalid hugr - see comment.
Hence, feels like it's time this had some eyes, also as guidance for reviewers of #1476.
closes #1322