From 649589c9e3f1fbd9cfff53a2adb8e1f9649fbe87 Mon Sep 17 00:00:00 2001 From: Mark Koch <48097969+mark-koch@users.noreply.github.com> Date: Thu, 21 Nov 2024 10:09:04 +0000 Subject: [PATCH] feat: Add array `repeat` and `scan` ops (#1633) Closes #1627 --- hugr-core/src/extension/prelude.rs | 1 + hugr-core/src/extension/prelude/array.rs | 293 ++++++++++++++++++- hugr-py/src/hugr/std/_json_defs/prelude.json | 183 ++++++++++++ specification/std_extensions/prelude.json | 183 ++++++++++++ 4 files changed, 659 insertions(+), 1 deletion(-) diff --git a/hugr-core/src/extension/prelude.rs b/hugr-core/src/extension/prelude.rs index ca338eae3..8e0eb0f4b 100644 --- a/hugr-core/src/extension/prelude.rs +++ b/hugr-core/src/extension/prelude.rs @@ -101,6 +101,7 @@ lazy_static! { NoopDef.add_to_extension(&mut prelude).unwrap(); LiftDef.add_to_extension(&mut prelude).unwrap(); array::ArrayOpDef::load_all_ops(&mut prelude).unwrap(); + array::ArrayScanDef.add_to_extension(&mut prelude).unwrap(); prelude }; /// An extension registry containing only the prelude diff --git a/hugr-core/src/extension/prelude/array.rs b/hugr-core/src/extension/prelude/array.rs index a15bf23cc..6013039d4 100644 --- a/hugr-core/src/extension/prelude/array.rs +++ b/hugr-core/src/extension/prelude/array.rs @@ -1,3 +1,6 @@ +use std::str::FromStr; + +use itertools::Itertools; use strum_macros::EnumIter; use strum_macros::EnumString; use strum_macros::IntoStaticStr; @@ -17,8 +20,10 @@ use crate::ops::ExtensionOp; use crate::ops::NamedOp; use crate::ops::OpName; use crate::type_row; +use crate::types::FuncTypeBase; use crate::types::FuncValueType; +use crate::types::RowVariable; use crate::types::TypeBound; use crate::types::Type; @@ -28,6 +33,7 @@ use crate::extension::SignatureError; use crate::types::PolyFuncTypeRV; use crate::types::type_param::TypeArg; +use crate::types::TypeRV; use crate::Extension; use super::PRELUDE_ID; @@ -46,6 +52,7 @@ pub enum ArrayOpDef { pop_left, pop_right, discard_empty, + repeat, } /// Static parameters for array operations. Includes array size. Type is part of the type scheme. @@ -118,6 +125,14 @@ impl ArrayOpDef { let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()]; match self { + repeat => { + let func = + Type::new_function(FuncValueType::new(type_row![], elem_ty_var.clone())); + PolyFuncTypeRV::new( + standard_params, + FuncValueType::new(vec![func], array_ty.clone()), + ) + } get => { let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()]; let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable); @@ -179,6 +194,10 @@ impl MakeOpDef for ArrayOpDef { fn description(&self) -> String { match self { ArrayOpDef::new_array => "Create a new array from elements", + ArrayOpDef::repeat => { + "Creates a new array whose elements are initialised by calling \ + the given function n times" + } ArrayOpDef::get => "Get an element from an array", ArrayOpDef::set => "Set an element in an array", ArrayOpDef::swap => "Swap two elements in an array", @@ -246,7 +265,7 @@ impl MakeExtensionOp for ArrayOp { ); vec![ty_arg] } - new_array | pop_left | pop_right | get | set | swap => { + new_array | repeat | pop_left | pop_right | get | set | swap => { vec![TypeArg::BoundedNat { n: self.size }, ty_arg] } } @@ -312,6 +331,192 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp { op.to_extension_op().unwrap() } +/// Name of the operation for the combined map/fold operation +pub const ARRAY_SCAN_OP_ID: OpName = OpName::new_inline("scan"); + +/// Definition of the array scan op. +#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] +pub struct ArrayScanDef; + +impl NamedOp for ArrayScanDef { + fn name(&self) -> OpName { + ARRAY_SCAN_OP_ID + } +} + +impl FromStr for ArrayScanDef { + type Err = (); + + fn from_str(s: &str) -> Result { + if s == ArrayScanDef.name() { + Ok(Self) + } else { + Err(()) + } + } +} + +impl ArrayScanDef { + /// To avoid recursion when defining the extension, take the type definition as an argument. + fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc { + // array, (T1, *A -> T2, *A), -> array, *A + let params = vec![ + TypeParam::max_nat(), + TypeBound::Any.into(), + TypeBound::Any.into(), + TypeParam::new_list(TypeBound::Any), + ]; + let n = TypeArg::new_var_use(0, TypeParam::max_nat()); + let t1 = Type::new_var_use(1, TypeBound::Any); + let t2 = Type::new_var_use(2, TypeBound::Any); + let s = TypeRV::new_row_var_use(3, TypeBound::Any); + PolyFuncTypeRV::new( + params, + FuncTypeBase::::new( + vec![ + instantiate(array_def, n.clone(), t1.clone()).into(), + Type::new_function(FuncTypeBase::::new( + vec![t1.into(), s.clone()], + vec![t2.clone().into(), s.clone()], + )) + .into(), + s.clone(), + ], + vec![instantiate(array_def, n, t2).into(), s], + ), + ) + .into() + } +} + +impl MakeOpDef for ArrayScanDef { + fn from_def(op_def: &OpDef) -> Result + where + Self: Sized, + { + crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension()) + } + + fn signature(&self) -> SignatureFunc { + self.signature_from_def(array_type_def()) + } + + fn extension(&self) -> ExtensionId { + PRELUDE_ID + } + + fn description(&self) -> String { + "A combination of map and foldl. Applies a function to each element \ + of the array with an accumulator that is passed through from start to \ + finish. Returns the resulting array and the final state of the \ + accumulator." + .into() + } + + /// Add an operation implemented as a [MakeOpDef], which can provide the data + /// required to define an [OpDef], to an extension. + // + // This method is re-defined here since we need to pass the array type def while + // computing the signature, to avoid recursive loops initializing the extension. + fn add_to_extension( + &self, + extension: &mut Extension, + ) -> Result<(), crate::extension::ExtensionBuildError> { + let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap()); + let def = extension.add_op(self.name(), self.description(), sig)?; + + self.post_opdef(def); + + Ok(()) + } +} + +/// Definition of the array scan op. +#[derive(Clone, Debug, PartialEq)] +pub struct ArrayScan { + /// The element type of the input array. + src_ty: Type, + /// The target element type of the output array. + tgt_ty: Type, + /// The accumulator types. + acc_tys: Vec, + /// Size of the array. + size: u64, +} + +impl ArrayScan { + fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec, size: u64) -> Self { + ArrayScan { + src_ty, + tgt_ty, + acc_tys, + size, + } + } +} + +impl NamedOp for ArrayScan { + fn name(&self) -> OpName { + ARRAY_SCAN_OP_ID + } +} + +impl MakeExtensionOp for ArrayScan { + fn from_extension_op(ext_op: &ExtensionOp) -> Result + where + Self: Sized, + { + let def = ArrayScanDef::from_def(ext_op.def())?; + def.instantiate(ext_op.args()) + } + + fn type_args(&self) -> Vec { + vec![ + TypeArg::BoundedNat { n: self.size }, + self.src_ty.clone().into(), + self.tgt_ty.clone().into(), + TypeArg::Sequence { + elems: self.acc_tys.clone().into_iter().map_into().collect(), + }, + ] + } +} + +impl MakeRegisteredOp for ArrayScan { + fn extension_id(&self) -> ExtensionId { + PRELUDE_ID + } + + fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry { + &PRELUDE_REGISTRY + } +} + +impl HasDef for ArrayScan { + type Def = ArrayScanDef; +} + +impl HasConcrete for ArrayScanDef { + type Concrete = ArrayScan; + + fn instantiate(&self, type_args: &[TypeArg]) -> Result { + match type_args { + [TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }] => + { + let acc_tys: Result<_, OpLoadError> = acc_tys + .iter() + .map(|acc_ty| match acc_ty { + TypeArg::Type { ty } => Ok(ty.clone()), + _ => Err(SignatureError::InvalidTypeArgs.into()), + }) + .collect(); + Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n)) + } + _ => Err(SignatureError::InvalidTypeArgs.into()), + } + } +} + #[cfg(test)] mod tests { use strum::IntoEnumIterator; @@ -320,6 +525,7 @@ mod tests { builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr}, extension::prelude::{BOOL_T, QB_T}, ops::{OpTrait, OpType}, + types::Signature, }; use super::*; @@ -459,4 +665,89 @@ mod tests { ) ); } + + #[test] + fn test_repeat() { + let size = 2; + let element_ty = QB_T; + let op = ArrayOpDef::repeat.to_concrete(element_ty.clone(), size); + + let optype: OpType = op.into(); + + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![Type::new_function(Signature::new(vec![], vec![QB_T]))].into(), + &vec![array_type(size, element_ty.clone())].into(), + ) + ); + } + + #[test] + fn test_scan_def() { + let op = ArrayScan::new(BOOL_T, QB_T, vec![USIZE_T], 2); + let optype: OpType = op.clone().into(); + let new_op: ArrayScan = optype.cast().unwrap(); + assert_eq!(new_op, op); + } + + #[test] + fn test_scan_map() { + let size = 2; + let src_ty = QB_T; + let tgt_ty = BOOL_T; + + let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![ + array_type(size, src_ty.clone()), + Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()])) + ] + .into(), + &vec![array_type(size, tgt_ty)].into(), + ) + ); + } + + #[test] + fn test_scan_accs() { + let size = 2; + let src_ty = QB_T; + let tgt_ty = BOOL_T; + let acc_ty1 = USIZE_T; + let acc_ty2 = QB_T; + + let op = ArrayScan::new( + src_ty.clone(), + tgt_ty.clone(), + vec![acc_ty1.clone(), acc_ty2.clone()], + size, + ); + let optype: OpType = op.into(); + let sig = optype.dataflow_signature().unwrap(); + + assert_eq!( + sig.io(), + ( + &vec![ + array_type(size, src_ty.clone()), + Type::new_function(Signature::new( + vec![src_ty, acc_ty1.clone(), acc_ty2.clone()], + vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()] + )), + acc_ty1.clone(), + acc_ty2.clone() + ] + .into(), + &vec![array_type(size, tgt_ty), acc_ty1, acc_ty2].into(), + ) + ); + } } diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index 014ba3ede..b48692b39 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -418,6 +418,189 @@ }, "binary": false }, + "repeat": { + "extension": "prelude", + "name": "repeat", + "description": "Creates a new array whose elements are initialised by calling the given function n times", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "G", + "input": [], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "extension_reqs": [] + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, + "scan": { + "extension": "prelude", + "name": "scan", + "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. Returns the resulting array and the final state of the accumulator.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "List", + "param": { + "tp": "Type", + "b": "A" + } + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "G", + "input": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "V", + "i": 2, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 2, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "set": { "extension": "prelude", "name": "set", diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index 014ba3ede..b48692b39 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -418,6 +418,189 @@ }, "binary": false }, + "repeat": { + "extension": "prelude", + "name": "repeat", + "description": "Creates a new array whose elements are initialised by calling the given function n times", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + } + ], + "body": { + "input": [ + { + "t": "G", + "input": [], + "output": [ + { + "t": "V", + "i": 1, + "b": "A" + } + ], + "extension_reqs": [] + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, + "scan": { + "extension": "prelude", + "name": "scan", + "description": "A combination of map and foldl. Applies a function to each element of the array with an accumulator that is passed through from start to finish. Returns the resulting array and the final state of the accumulator.", + "signature": { + "params": [ + { + "tp": "BoundedNat", + "bound": null + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "Type", + "b": "A" + }, + { + "tp": "List", + "param": { + "tp": "Type", + "b": "A" + } + } + ], + "body": { + "input": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 1, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "G", + "input": [ + { + "t": "V", + "i": 1, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "V", + "i": 2, + "b": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "output": [ + { + "t": "Opaque", + "extension": "prelude", + "id": "array", + "args": [ + { + "tya": "Variable", + "idx": 0, + "cached_decl": { + "tp": "BoundedNat", + "bound": null + } + }, + { + "tya": "Type", + "ty": { + "t": "V", + "i": 2, + "b": "A" + } + } + ], + "bound": "A" + }, + { + "t": "R", + "i": 3, + "b": "A" + } + ], + "extension_reqs": [] + } + }, + "binary": false + }, "set": { "extension": "prelude", "name": "set",