Skip to content

Commit

Permalink
feat: Emulate TypeBounds on parameters via constraints. (#1624)
Browse files Browse the repository at this point in the history
This PR translates some `TypeBound`s in `hugr-core` to the `nonlinear`
constraint in `hugr-model`. This translation only occurs on parameters
that take a runtime type directly.

As a driveby change before the model stabilises, this PR also moves
constraints out of the parameter lists into their own list. In its
previous form this could have led to confusions about which parameter a
local variable index refers to when a constraint is situated between two
parameters in the list. We also remove constraints from aliases for now.

Closes #1637.
  • Loading branch information
zrho authored Nov 20, 2024
1 parent 9a43956 commit 6bd094f
Show file tree
Hide file tree
Showing 13 changed files with 349 additions and 196 deletions.
85 changes: 68 additions & 17 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::{
type_param::{TypeArgVariable, TypeParam},
type_row::TypeRowBase,
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
TypeBase, TypeEnum,
TypeBase, TypeBound, TypeEnum,
},
Direction, Hugr, HugrView, IncomingPort, Node, Port,
};
Expand Down Expand Up @@ -44,8 +44,21 @@ struct Context<'a> {
bump: &'a Bump,
/// Stores the terms that we have already seen to avoid duplicates.
term_map: FxHashMap<model::Term<'a>, model::TermId>,

/// The current scope for local variables.
///
/// This is set to the id of the smallest enclosing node that defines a polymorphic type.
/// We use this when exporting local variables in terms.
local_scope: Option<model::NodeId>,

/// Constraints to be added to the local scope.
///
/// When exporting a node that defines a polymorphic type, we use this field
/// to collect the constraints that need to be added to that polymorphic
/// type. Currently this is used to record `nonlinear` constraints on uses
/// of `TypeParam::Type` with a `TypeBound::Copyable` bound.
local_constraints: Vec<model::TermId>,

/// Mapping from extension operations to their declarations.
decl_operations: FxHashMap<(ExtensionId, OpName), model::NodeId>,
}
Expand All @@ -63,6 +76,7 @@ impl<'a> Context<'a> {
term_map: FxHashMap::default(),
local_scope: None,
decl_operations: FxHashMap::default(),
local_constraints: Vec::new(),
}
}

Expand Down Expand Up @@ -173,9 +187,11 @@ impl<'a> Context<'a> {
}

fn with_local_scope<T>(&mut self, node: model::NodeId, f: impl FnOnce(&mut Self) -> T) -> T {
let old_scope = self.local_scope.replace(node);
let prev_local_scope = self.local_scope.replace(node);
let prev_local_constraints = std::mem::take(&mut self.local_constraints);
let result = f(self);
self.local_scope = old_scope;
self.local_scope = prev_local_scope;
self.local_constraints = prev_local_constraints;
result
}

Expand Down Expand Up @@ -232,10 +248,11 @@ impl<'a> Context<'a> {

OpType::FuncDefn(func) => self.with_local_scope(node_id, |this| {
let name = this.get_func_name(node).unwrap();
let (params, signature) = this.export_poly_func_type(&func.signature);
let (params, constraints, signature) = this.export_poly_func_type(&func.signature);
let decl = this.bump.alloc(model::FuncDecl {
name,
params,
constraints,
signature,
});
let extensions = this.export_ext_set(&func.signature.body().extension_reqs);
Expand All @@ -247,10 +264,11 @@ impl<'a> Context<'a> {

OpType::FuncDecl(func) => self.with_local_scope(node_id, |this| {
let name = this.get_func_name(node).unwrap();
let (params, func) = this.export_poly_func_type(&func.signature);
let (params, constraints, func) = this.export_poly_func_type(&func.signature);
let decl = this.bump.alloc(model::FuncDecl {
name,
params,
constraints,
signature: func,
});
model::Operation::DeclareFunc { decl }
Expand Down Expand Up @@ -450,10 +468,11 @@ impl<'a> Context<'a> {

let decl = self.with_local_scope(node, |this| {
let name = this.make_qualified_name(opdef.extension(), opdef.name());
let (params, r#type) = this.export_poly_func_type(poly_func_type);
let (params, constraints, r#type) = this.export_poly_func_type(poly_func_type);
let decl = this.bump.alloc(model::OperationDecl {
name,
params,
constraints,
r#type,
});
decl
Expand Down Expand Up @@ -671,22 +690,36 @@ impl<'a> Context<'a> {
regions.into_bump_slice()
}

/// Exports a polymorphic function type.
///
/// The returned triple consists of:
/// - The static parameters of the polymorphic function type.
/// - The constraints of the polymorphic function type.
/// - The function type itself.
pub fn export_poly_func_type<RV: MaybeRV>(
&mut self,
t: &PolyFuncTypeBase<RV>,
) -> (&'a [model::Param<'a>], model::TermId) {
) -> (&'a [model::Param<'a>], &'a [model::TermId], model::TermId) {
let mut params = BumpVec::with_capacity_in(t.params().len(), self.bump);
let scope = self
.local_scope
.expect("exporting poly func type outside of local scope");

for (i, param) in t.params().iter().enumerate() {
let name = self.bump.alloc_str(&i.to_string());
let r#type = self.export_type_param(param);
let param = model::Param::Implicit { name, r#type };
let r#type = self.export_type_param(param, Some(model::LocalRef::Index(scope, i as _)));
let param = model::Param {
name,
r#type,
sort: model::ParamSort::Implicit,
};
params.push(param)
}

let constraints = self.bump.alloc_slice_copy(&self.local_constraints);
let body = self.export_func_type(t.body());

(params.into_bump_slice(), body)
(params.into_bump_slice(), constraints, body)
}

pub fn export_type<RV: MaybeRV>(&mut self, t: &TypeBase<RV>) -> model::TermId {
Expand All @@ -703,7 +736,6 @@ impl<'a> Context<'a> {
}
TypeEnum::Function(func) => self.export_func_type(func),
TypeEnum::Variable(index, _) => {
// This ignores the type bound for now
let node = self.local_scope.expect("local variable out of scope");
self.make_term(model::Term::Var(model::LocalRef::Index(node, *index as _)))
}
Expand Down Expand Up @@ -794,20 +826,39 @@ impl<'a> Context<'a> {
self.make_term(model::Term::List { items, tail: None })
}

pub fn export_type_param(&mut self, t: &TypeParam) -> model::TermId {
/// Exports a `TypeParam` to a term.
///
/// The `var` argument is set when the type parameter being exported is the
/// type of a parameter to a polymorphic definition. In that case we can
/// generate a `nonlinear` constraint for the type of runtime types marked as
/// `TypeBound::Copyable`.
pub fn export_type_param(
&mut self,
t: &TypeParam,
var: Option<model::LocalRef<'static>>,
) -> model::TermId {
match t {
// This ignores the type bound for now.
TypeParam::Type { .. } => self.make_term(model::Term::Type),
// This ignores the type bound for now.
TypeParam::Type { b } => {
if let (Some(var), TypeBound::Copyable) = (var, b) {
let term = self.make_term(model::Term::Var(var));
let non_linear = self.make_term(model::Term::NonLinearConstraint { term });
self.local_constraints.push(non_linear);
}

self.make_term(model::Term::Type)
}
// This ignores the bound on the natural for now.
TypeParam::BoundedNat { .. } => self.make_term(model::Term::NatType),
TypeParam::String => self.make_term(model::Term::StrType),
TypeParam::List { param } => {
let item_type = self.export_type_param(param);
let item_type = self.export_type_param(param, None);
self.make_term(model::Term::ListType { item_type })
}
TypeParam::Tuple { params } => {
let items = self.bump.alloc_slice_fill_iter(
params.iter().map(|param| self.export_type_param(param)),
params
.iter()
.map(|param| self.export_type_param(param, None)),
);
let types = self.make_term(model::Term::List { items, tail: None });
self.make_term(model::Term::ApplyFull {
Expand Down
107 changes: 72 additions & 35 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ struct Context<'a> {
/// A map from `NodeId` to the imported `Node`.
nodes: FxHashMap<model::NodeId, Node>,

/// The types of the local variables that are currently in scope.
local_variables: FxIndexMap<&'a str, model::TermId>,
/// The local variables that are currently in scope.
local_variables: FxIndexMap<&'a str, LocalVar>,

custom_name_cache: FxHashMap<&'a str, (ExtensionId, SmolStr)>,
}
Expand Down Expand Up @@ -155,20 +155,20 @@ impl<'a> Context<'a> {
.ok_or_else(|| model::ModelError::RegionNotFound(region_id).into())
}

/// Looks up a [`LocalRef`] within the current scope and returns its index and type.
/// Looks up a [`LocalRef`] within the current scope.
fn resolve_local_ref(
&self,
local_ref: &model::LocalRef,
) -> Result<(usize, model::TermId), ImportError> {
) -> Result<(usize, LocalVar), ImportError> {
let term = match local_ref {
model::LocalRef::Index(_, index) => self
.local_variables
.get_index(*index as usize)
.map(|(_, term)| (*index as usize, *term)),
.map(|(_, v)| (*index as usize, *v)),
model::LocalRef::Named(name) => self
.local_variables
.get_full(name)
.map(|(index, _, term)| (index, *term)),
.map(|(index, _, v)| (index, *v)),
};

term.ok_or_else(|| model::ModelError::InvalidLocal(local_ref.to_string()).into())
Expand Down Expand Up @@ -898,41 +898,49 @@ impl<'a> Context<'a> {
self.with_local_socpe(|ctx| {
let mut imported_params = Vec::with_capacity(decl.params.len());

for param in decl.params {
// TODO: `PolyFuncType` should be able to handle constraints
// and distinguish between implicit and explicit parameters.
match param {
model::Param::Implicit { name, r#type } => {
imported_params.push(ctx.import_type_param(*r#type)?);
ctx.local_variables.insert(name, *r#type);
}
model::Param::Explicit { name, r#type } => {
imported_params.push(ctx.import_type_param(*r#type)?);
ctx.local_variables.insert(name, *r#type);
}
model::Param::Constraint { constraint: _ } => {
return Err(error_unsupported!("constraints"));
ctx.local_variables.extend(
decl.params
.iter()
.map(|param| (param.name, LocalVar::new(param.r#type))),
);

for constraint in decl.constraints {
match ctx.get_term(*constraint)? {
model::Term::NonLinearConstraint { term } => {
let model::Term::Var(var) = ctx.get_term(*term)? else {
return Err(error_unsupported!(
"constraint on term that is not a variable"
));
};

let var = ctx.resolve_local_ref(var)?.0;
ctx.local_variables[var].bound = TypeBound::Copyable;
}
_ => return Err(error_unsupported!("constraint other than copy or discard")),
}
}

for (index, param) in decl.params.iter().enumerate() {
// NOTE: `PolyFuncType` only has explicit type parameters at present.
let bound = ctx.local_variables[index].bound;
imported_params.push(ctx.import_type_param(param.r#type, bound)?);
}

let body = ctx.import_func_type::<RV>(decl.signature)?;
in_scope(ctx, PolyFuncTypeBase::new(imported_params, body))
})
}

/// Import a [`TypeParam`] from a term that represents a static type.
fn import_type_param(&mut self, term_id: model::TermId) -> Result<TypeParam, ImportError> {
fn import_type_param(
&mut self,
term_id: model::TermId,
bound: TypeBound,
) -> Result<TypeParam, ImportError> {
match self.get_term(term_id)? {
model::Term::Wildcard => Err(error_uninferred!("wildcard")),

model::Term::Type => {
// As part of the migration from `TypeBound`s to constraints, we pretend that all
// `TypeBound`s are copyable.
Ok(TypeParam::Type {
b: TypeBound::Copyable,
})
}
model::Term::Type => Ok(TypeParam::Type { b: bound }),

model::Term::StaticType => Err(error_unsupported!("`type` as `TypeParam`")),
model::Term::Constraint => Err(error_unsupported!("`constraint` as `TypeParam`")),
Expand All @@ -944,7 +952,9 @@ impl<'a> Context<'a> {
model::Term::FuncType { .. } => Err(error_unsupported!("`(fn ...)` as `TypeParam`")),

model::Term::ListType { item_type } => {
let param = Box::new(self.import_type_param(*item_type)?);
// At present `hugr-model` has no way to express that the item
// type of a list must be copyable. Therefore we import it as `Any`.
let param = Box::new(self.import_type_param(*item_type, TypeBound::Any)?);
Ok(TypeParam::List { param })
}

Expand All @@ -958,15 +968,18 @@ impl<'a> Context<'a> {
| model::Term::List { .. }
| model::Term::ExtSet { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()),
| model::Term::Control { .. }
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}

model::Term::ControlType => {
Err(error_unsupported!("type of control types as `TypeParam`"))
}
}
}

/// Import a `TypeArg` froma term that represents a static type or value.
/// Import a `TypeArg` from a term that represents a static type or value.
fn import_type_arg(&mut self, term_id: model::TermId) -> Result<TypeArg, ImportError> {
match self.get_term(term_id)? {
model::Term::Wildcard => Err(error_uninferred!("wildcard")),
Expand All @@ -975,8 +988,8 @@ impl<'a> Context<'a> {
}

model::Term::Var(var) => {
let (index, var_type) = self.resolve_local_ref(var)?;
let decl = self.import_type_param(var_type)?;
let (index, var) = self.resolve_local_ref(var)?;
let decl = self.import_type_param(var.r#type, var.bound)?;
Ok(TypeArg::new_var_use(index, decl))
}

Expand Down Expand Up @@ -1014,7 +1027,10 @@ impl<'a> Context<'a> {

model::Term::FuncType { .. }
| model::Term::Adt { .. }
| model::Term::Control { .. } => Err(model::ModelError::TypeError(term_id).into()),
| model::Term::Control { .. }
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
}
}

Expand Down Expand Up @@ -1115,7 +1131,10 @@ impl<'a> Context<'a> {
| model::Term::List { .. }
| model::Term::Control { .. }
| model::Term::ControlType
| model::Term::Nat(_) => Err(model::ModelError::TypeError(term_id).into()),
| model::Term::Nat(_)
| model::Term::NonLinearConstraint { .. } => {
Err(model::ModelError::TypeError(term_id).into())
}
}
}

Expand Down Expand Up @@ -1291,3 +1310,21 @@ impl<'a> Names<'a> {
Ok(Self { items })
}
}

/// Information about a local variable.
#[derive(Debug, Clone, Copy)]
struct LocalVar {
/// The type of the variable.
r#type: model::TermId,
/// The type bound of the variable.
bound: TypeBound,
}

impl LocalVar {
pub fn new(r#type: model::TermId) -> Self {
Self {
r#type,
bound: TypeBound::Any,
}
}
}
7 changes: 7 additions & 0 deletions hugr-core/tests/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,10 @@ pub fn test_roundtrip_params() {
"../../hugr-model/tests/fixtures/model-params.edn"
)));
}

#[test]
pub fn test_roundtrip_constraints() {
insta::assert_snapshot!(roundtrip(include_str!(
"../../hugr-model/tests/fixtures/model-constraints.edn"
)));
}
Loading

0 comments on commit 6bd094f

Please sign in to comment.