Skip to content

Commit

Permalink
feat!: MakeOpDef has new extension method. (#1266)
Browse files Browse the repository at this point in the history
Closes #1241 

Used to ensure `try_from_name` only succeeds when the extension is
correct

BREAKING CHANGE:
- `MakeOpDef` trait has new required `extension` method.
- `try_from_name` takes the OpDef extension and checks it
  • Loading branch information
ss2165 authored Jul 8, 2024
1 parent cdc3739 commit 75192f7
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 6 deletions.
22 changes: 20 additions & 2 deletions hugr-core/src/extension/simple_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ pub enum OpLoadError {
NotMember(String),
#[error("Type args invalid: {0}.")]
InvalidArgs(#[from] SignatureError),
#[error("OpDef belongs to extension {0}, expected {1}.")]
WrongExtension(ExtensionId, ExtensionId),
}

impl<T> NamedOp for T
Expand All @@ -51,6 +53,9 @@ pub trait MakeOpDef: NamedOp {
/// Return the signature (polymorphic function type) of the operation.
fn signature(&self) -> SignatureFunc;

/// The ID of the extension this operation is defined in.
fn extension(&self) -> ExtensionId;

/// Description of the operation. By default, the same as `self.name()`.
fn description(&self) -> String {
self.name().to_string()
Expand Down Expand Up @@ -138,11 +143,20 @@ impl<T: MakeOpDef> MakeExtensionOp for T {

/// Load an [MakeOpDef] from its name.
/// See [strum_macros::EnumString].
pub fn try_from_name<T>(name: &OpNameRef) -> Result<T, OpLoadError>
pub fn try_from_name<T>(name: &OpNameRef, def_extension: &ExtensionId) -> Result<T, OpLoadError>
where
T: std::str::FromStr + MakeOpDef,
{
T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))
let op = T::from_str(name).map_err(|_| OpLoadError::NotMember(name.to_string()))?;
let expected_extension = op.extension();
if def_extension != &expected_extension {
return Err(OpLoadError::WrongExtension(
def_extension.clone(),
expected_extension,
));
}

Ok(op)
}

/// Wrap an [MakeExtensionOp] with an extension registry to allow type computation.
Expand Down Expand Up @@ -245,6 +259,10 @@ mod test {
fn from_def(_op_def: &OpDef) -> Result<Self, OpLoadError> {
Ok(Self::Dumb)
}

fn extension(&self) -> ExtensionId {
EXT_ID.to_owned()
}
}
const_extension_ids! {
const EXT_ID: ExtensionId = "DummyExt";
Expand Down
6 changes: 5 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ pub enum ConvertOpDef {

impl MakeOpDef for ConvertOpDef {
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
crate::extension::simple_op::try_from_name(op_def.name())
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension())
}

fn extension(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn signature(&self) -> SignatureFunc {
Expand Down
6 changes: 5 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/float_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ pub enum FloatOps {

impl MakeOpDef for FloatOps {
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
crate::extension::simple_op::try_from_name(op_def.name())
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension())
}

fn extension(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn signature(&self) -> SignatureFunc {
Expand Down
6 changes: 5 additions & 1 deletion hugr-core/src/std_extensions/arithmetic/int_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,11 @@ pub enum IntOpDef {

impl MakeOpDef for IntOpDef {
fn from_def(op_def: &OpDef) -> Result<Self, crate::extension::simple_op::OpLoadError> {
crate::extension::simple_op::try_from_name(op_def.name())
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension())
}

fn extension(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn signature(&self) -> SignatureFunc {
Expand Down
10 changes: 9 additions & 1 deletion hugr-core/src/std_extensions/logic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,11 @@ impl MakeOpDef for NaryLogic {
}

fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError> {
try_from_name(op_def.name())
try_from_name(op_def.name(), op_def.extension())
}

fn extension(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn post_opdef(&self, def: &mut OpDef) {
Expand Down Expand Up @@ -127,6 +131,10 @@ impl MakeOpDef for NotOp {
}
}

fn extension(&self) -> ExtensionId {
EXTENSION_ID.to_owned()
}

fn signature(&self) -> SignatureFunc {
FunctionType::new_endo(type_row![BOOL_T]).into()
}
Expand Down

0 comments on commit 75192f7

Please sign in to comment.